Support google-cloud-automl >=2.1.0 (#13505)

This commit is contained in:
Kamil Breguła 2021-01-11 09:39:44 +01:00 коммит произвёл GitHub
Родитель 947dbb73bb
Коммит a6f999b62e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 134 добавлений и 113 удалений

Просмотреть файл

@ -29,6 +29,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
| Library name | Previous constraints | Current constraints | |
| --- | --- | --- | --- |
| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
| [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
| [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |

Просмотреть файл

@ -47,7 +47,7 @@ GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1")
GCP_AUTOML_DATASET_BUCKET = os.environ.get(
"GCP_AUTOML_DATASET_BUCKET", "gs://cloud-ml-tables-data/bank-marketing.csv"
)
TARGET = os.environ.get("GCP_AUTOML_TARGET", "Class")
TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit")
# Example values
MODEL_ID = "TBL123456"
@ -76,9 +76,9 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
Using column name returns spec of the column.
"""
for column in columns_specs:
if column["displayName"] == column_name:
if column["display_name"] == column_name:
return extract_object_id(column)
return ""
raise Exception(f"Unknown target column: {column_name}")
# Example DAG to create dataset, train model_id and deploy it.

Просмотреть файл

@ -20,22 +20,23 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union
from cached_property import cached_property
from google.api_core.operation import Operation
from google.api_core.retry import Retry
from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from google.cloud.automl_v1beta1.types import (
from google.cloud.automl_v1beta1 import (
AutoMlClient,
BatchPredictInputConfig,
BatchPredictOutputConfig,
ColumnSpec,
Dataset,
ExamplePayload,
FieldMask,
ImageObjectDetectionModelDeploymentMetadata,
InputConfig,
Model,
Operation,
PredictionServiceClient,
PredictResponse,
TableSpec,
)
from google.protobuf.field_mask_pb2 import FieldMask
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@ -123,9 +124,9 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.get_conn()
parent = client.location_path(project_id, location)
parent = f"projects/{project_id}/locations/{location}"
return client.create_model(
parent=parent, model=model, retry=retry, timeout=timeout, metadata=metadata
request={'parent': parent, 'model': model}, retry=retry, timeout=timeout, metadata=metadata or ()
)
@GoogleBaseHook.fallback_to_default_project_id
@ -176,15 +177,17 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.prediction_client
name = client.model_path(project=project_id, location=location, model=model_id)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.batch_predict(
name=name,
input_config=input_config,
output_config=output_config,
params=params,
request={
'name': name,
'input_config': input_config,
'output_config': output_config,
'params': params,
},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result
@ -229,14 +232,12 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types.PredictResponse` instance
"""
client = self.prediction_client
name = client.model_path(project=project_id, location=location, model=model_id)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.predict(
name=name,
payload=payload,
params=params,
request={'name': name, 'payload': payload, 'params': params},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result
@ -273,13 +274,12 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types.Dataset` instance.
"""
client = self.get_conn()
parent = client.location_path(project=project_id, location=location)
parent = f"projects/{project_id}/locations/{location}"
result = client.create_dataset(
parent=parent,
dataset=dataset,
request={'parent': parent, 'dataset': dataset},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result
@ -319,13 +319,12 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.get_conn()
name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
result = client.import_data(
name=name,
input_config=input_config,
request={'name': name, 'input_config': input_config},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result
@ -385,13 +384,10 @@ class CloudAutoMLHook(GoogleBaseHook):
table_spec=table_spec_id,
)
result = client.list_column_specs(
parent=parent,
field_mask=field_mask,
filter_=filter_,
page_size=page_size,
request={'parent': parent, 'field_mask': field_mask, 'filter': filter_, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result
@ -427,8 +423,10 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types.Model` instance.
"""
client = self.get_conn()
name = client.model_path(project=project_id, location=location, model=model_id)
result = client.get_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.get_model(
request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result
@GoogleBaseHook.fallback_to_default_project_id
@ -463,8 +461,10 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
"""
client = self.get_conn()
name = client.model_path(project=project_id, location=location, model=model_id)
result = client.delete_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.delete_model(
request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result
def update_dataset(
@ -497,11 +497,10 @@ class CloudAutoMLHook(GoogleBaseHook):
"""
client = self.get_conn()
result = client.update_dataset(
dataset=dataset,
update_mask=update_mask,
request={'dataset': dataset, 'update_mask': update_mask},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result
@ -547,13 +546,15 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
"""
client = self.get_conn()
name = client.model_path(project=project_id, location=location, model=model_id)
name = f"projects/{project_id}/locations/{location}/models/{model_id}"
result = client.deploy_model(
name=name,
request={
'name': name,
'image_object_detection_model_deployment_metadata': image_detection_metadata,
},
retry=retry,
timeout=timeout,
metadata=metadata,
image_object_detection_model_deployment_metadata=image_detection_metadata,
metadata=metadata or (),
)
return result
@ -601,14 +602,12 @@ class CloudAutoMLHook(GoogleBaseHook):
of the response through the `options` parameter.
"""
client = self.get_conn()
parent = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
result = client.list_table_specs(
parent=parent,
filter_=filter_,
page_size=page_size,
request={'parent': parent, 'filter': filter_, 'page_size': page_size},
retry=retry,
timeout=timeout,
metadata=metadata,
metadata=metadata or (),
)
return result
@ -644,8 +643,10 @@ class CloudAutoMLHook(GoogleBaseHook):
of the response through the `options` parameter.
"""
client = self.get_conn()
parent = client.location_path(project=project_id, location=location)
result = client.list_datasets(parent=parent, retry=retry, timeout=timeout, metadata=metadata)
parent = f"projects/{project_id}/locations/{location}"
result = client.list_datasets(
request={'parent': parent}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result
@GoogleBaseHook.fallback_to_default_project_id
@ -680,6 +681,8 @@ class CloudAutoMLHook(GoogleBaseHook):
:return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
"""
client = self.get_conn()
name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
result = client.delete_dataset(name=name, retry=retry, timeout=timeout, metadata=metadata)
name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
result = client.delete_dataset(
request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
)
return result

Просмотреть файл

@ -22,7 +22,14 @@ import ast
from typing import Dict, List, Optional, Sequence, Tuple, Union
from google.api_core.retry import Retry
from google.protobuf.json_format import MessageToDict
from google.cloud.automl_v1beta1 import (
BatchPredictResult,
ColumnSpec,
Dataset,
Model,
PredictResponse,
TableSpec,
)
from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
@ -113,7 +120,7 @@ class AutoMLTrainModelOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
result = Model.to_dict(operation.result())
model_id = hook.extract_object_id(result)
self.log.info("Model created: %s", model_id)
@ -212,7 +219,7 @@ class AutoMLPredictOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(result)
return PredictResponse.to_dict(result)
class AutoMLBatchPredictOperator(BaseOperator):
@ -324,7 +331,7 @@ class AutoMLBatchPredictOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
result = BatchPredictResult.to_dict(operation.result())
self.log.info("Batch prediction ready.")
return result
@ -414,7 +421,7 @@ class AutoMLCreateDatasetOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(result)
result = Dataset.to_dict(result)
dataset_id = hook.extract_object_id(result)
self.log.info("Creating completed. Dataset id: %s", dataset_id)
@ -513,9 +520,8 @@ class AutoMLImportDataOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
operation.result()
self.log.info("Import completed")
return result
class AutoMLTablesListColumnSpecsOperator(BaseOperator):
@ -627,7 +633,7 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = [MessageToDict(spec) for spec in page_iterator]
result = [ColumnSpec.to_dict(spec) for spec in page_iterator]
self.log.info("Columns specs obtained.")
return result
@ -718,7 +724,7 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator):
metadata=self.metadata,
)
self.log.info("Dataset updated.")
return MessageToDict(result)
return Dataset.to_dict(result)
class AutoMLGetModelOperator(BaseOperator):
@ -804,7 +810,7 @@ class AutoMLGetModelOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
return MessageToDict(result)
return Model.to_dict(result)
class AutoMLDeleteModelOperator(BaseOperator):
@ -890,8 +896,7 @@ class AutoMLDeleteModelOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
return result
operation.result()
class AutoMLDeployModelOperator(BaseOperator):
@ -991,9 +996,8 @@ class AutoMLDeployModelOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = MessageToDict(operation.result())
operation.result()
self.log.info("Model deployed.")
return result
class AutoMLTablesListTableSpecsOperator(BaseOperator):
@ -1092,7 +1096,7 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = [MessageToDict(spec) for spec in page_iterator]
result = [TableSpec.to_dict(spec) for spec in page_iterator]
self.log.info(result)
self.log.info("Table specs obtained.")
return result
@ -1173,7 +1177,7 @@ class AutoMLListDatasetOperator(BaseOperator):
timeout=self.timeout,
metadata=self.metadata,
)
result = [MessageToDict(dataset) for dataset in page_iterator]
result = [Dataset.to_dict(dataset) for dataset in page_iterator]
self.log.info("Datasets obtained.")
self.xcom_push(

Просмотреть файл

@ -281,7 +281,7 @@ google = [
'google-api-python-client>=1.6.0,<2.0.0',
'google-auth>=1.0.0,<2.0.0',
'google-auth-httplib2>=0.0.1',
'google-cloud-automl>=0.4.0,<2.0.0',
'google-cloud-automl>=2.1.0,<3.0.0',
'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
'google-cloud-bigtable>=1.0.0,<2.0.0',
'google-cloud-container>=0.1.1,<2.0.0',

Просмотреть файл

@ -19,7 +19,7 @@
import unittest
from unittest import mock
from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from google.cloud.automl_v1beta1 import AutoMlClient
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id
@ -38,9 +38,9 @@ MODEL = {
"tables_model_metadata": {"train_budget_milli_node_hours": 1000},
}
LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID)
DATASET_PATH = AutoMlClient.dataset_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID)
LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
INPUT_CONFIG = {"input": "value"}
OUTPUT_CONFIG = {"output": "value"}
@ -81,7 +81,7 @@ class TestAuoMLHook(unittest.TestCase):
self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
mock_create_model.assert_called_once_with(
parent=LOCATION_PATH, model=MODEL, retry=None, timeout=None, metadata=None
request=dict(parent=LOCATION_PATH, model=MODEL), retry=None, timeout=None, metadata=()
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict")
@ -95,13 +95,12 @@ class TestAuoMLHook(unittest.TestCase):
)
mock_batch_predict.assert_called_once_with(
name=MODEL_PATH,
input_config=INPUT_CONFIG,
output_config=OUTPUT_CONFIG,
params=None,
request=dict(
name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None
),
retry=None,
timeout=None,
metadata=None,
metadata=(),
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict")
@ -114,12 +113,10 @@ class TestAuoMLHook(unittest.TestCase):
)
mock_predict.assert_called_once_with(
name=MODEL_PATH,
payload=PAYLOAD,
params=None,
request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None),
retry=None,
timeout=None,
metadata=None,
metadata=(),
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset")
@ -127,11 +124,10 @@ class TestAuoMLHook(unittest.TestCase):
self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
mock_create_dataset.assert_called_once_with(
parent=LOCATION_PATH,
dataset=DATASET,
request=dict(parent=LOCATION_PATH, dataset=DATASET),
retry=None,
timeout=None,
metadata=None,
metadata=(),
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data")
@ -144,11 +140,10 @@ class TestAuoMLHook(unittest.TestCase):
)
mock_import_data.assert_called_once_with(
name=DATASET_PATH,
input_config=INPUT_CONFIG,
request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG),
retry=None,
timeout=None,
metadata=None,
metadata=(),
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs")
@ -169,26 +164,27 @@ class TestAuoMLHook(unittest.TestCase):
parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec)
mock_list_column_specs.assert_called_once_with(
parent=parent,
field_mask=MASK,
filter_=filter_,
page_size=page_size,
request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size),
retry=None,
timeout=None,
metadata=None,
metadata=(),
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model")
def test_get_model(self, mock_get_model):
self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
mock_get_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None)
mock_get_model.assert_called_once_with(
request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=()
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model")
def test_delete_model(self, mock_delete_model):
self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
mock_delete_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None)
mock_delete_model.assert_called_once_with(
request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=()
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset")
def test_update_dataset(self, mock_update_dataset):
@ -198,7 +194,7 @@ class TestAuoMLHook(unittest.TestCase):
)
mock_update_dataset.assert_called_once_with(
dataset=DATASET, update_mask=MASK, retry=None, timeout=None, metadata=None
request=dict(dataset=DATASET, update_mask=MASK), retry=None, timeout=None, metadata=()
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model")
@ -213,11 +209,13 @@ class TestAuoMLHook(unittest.TestCase):
)
mock_deploy_model.assert_called_once_with(
name=MODEL_PATH,
request=dict(
name=MODEL_PATH,
image_object_detection_model_deployment_metadata=image_detection_metadata,
),
retry=None,
timeout=None,
metadata=None,
image_object_detection_model_deployment_metadata=image_detection_metadata,
metadata=(),
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs")
@ -234,12 +232,10 @@ class TestAuoMLHook(unittest.TestCase):
)
mock_list_table_specs.assert_called_once_with(
parent=DATASET_PATH,
filter_=filter_,
page_size=page_size,
request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size),
retry=None,
timeout=None,
metadata=None,
metadata=(),
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets")
@ -247,7 +243,7 @@ class TestAuoMLHook(unittest.TestCase):
self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
mock_list_datasets.assert_called_once_with(
parent=LOCATION_PATH, retry=None, timeout=None, metadata=None
request=dict(parent=LOCATION_PATH), retry=None, timeout=None, metadata=()
)
@mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset")
@ -255,5 +251,5 @@ class TestAuoMLHook(unittest.TestCase):
self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
mock_delete_dataset.assert_called_once_with(
name=DATASET_PATH, retry=None, timeout=None, metadata=None
request=dict(name=DATASET_PATH), retry=None, timeout=None, metadata=()
)

Просмотреть файл

@ -20,8 +20,9 @@ import copy
import unittest
from unittest import mock
from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.operators.automl import (
AutoMLBatchPredictOperator,
AutoMLCreateDatasetOperator,
@ -43,7 +44,7 @@ TASK_ID = "test-automl-hook"
GCP_PROJECT_ID = "test-project"
GCP_LOCATION = "test-location"
MODEL_NAME = "test_model"
MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
MODEL_ID = "TBL9195602771183665152"
DATASET_ID = "TBL123456789"
MODEL = {
"display_name": MODEL_NAME,
@ -51,8 +52,9 @@ MODEL = {
"tables_model_metadata": {"train_budget_milli_node_hours": 1000},
}
LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID)
LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
INPUT_CONFIG = {"input": "value"}
OUTPUT_CONFIG = {"output": "value"}
@ -60,12 +62,15 @@ PAYLOAD = {"test": "payload"}
DATASET = {"dataset_id": "data"}
MASK = {"field": "mask"}
extract_object_id = CloudAutoMLHook.extract_object_id
class TestAutoMLTrainModelOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
def test_execute(self, mock_hook, mock_xcom):
mock_hook.return_value.extract_object_id.return_value = MODEL_ID
mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH)
mock_hook.return_value.extract_object_id = extract_object_id
op = AutoMLTrainModelOperator(
model=MODEL,
location=GCP_LOCATION,
@ -87,6 +92,9 @@ class TestAutoMLTrainModelOperator(unittest.TestCase):
class TestAutoMLBatchPredictOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
def test_execute(self, mock_hook):
mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
mock_hook.return_value.extract_object_id = extract_object_id
op = AutoMLBatchPredictOperator(
model_id=MODEL_ID,
location=GCP_LOCATION,
@ -113,6 +121,8 @@ class TestAutoMLBatchPredictOperator(unittest.TestCase):
class TestAutoMLPredictOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
def test_execute(self, mock_hook):
mock_hook.return_value.predict.return_value = PredictResponse()
op = AutoMLPredictOperator(
model_id=MODEL_ID,
location=GCP_LOCATION,
@ -137,7 +147,9 @@ class TestAutoMLCreateImportOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator.xcom_push")
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
def test_execute(self, mock_hook, mock_xcom):
mock_hook.return_value.extract_object_id.return_value = DATASET_ID
mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
mock_hook.return_value.extract_object_id = extract_object_id
op = AutoMLCreateDatasetOperator(
dataset=DATASET,
location=GCP_LOCATION,
@ -191,6 +203,8 @@ class TestAutoMLListColumnsSpecsOperator(unittest.TestCase):
class TestAutoMLUpdateDatasetOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
def test_execute(self, mock_hook):
mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
dataset = copy.deepcopy(DATASET)
dataset["name"] = DATASET_ID
@ -213,6 +227,9 @@ class TestAutoMLUpdateDatasetOperator(unittest.TestCase):
class TestAutoMLGetModelOperator(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
def test_execute(self, mock_hook):
mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH)
mock_hook.return_value.extract_object_id = extract_object_id
op = AutoMLGetModelOperator(
model_id=MODEL_ID,
location=GCP_LOCATION,