Support google-cloud-automl >=2.1.0 (#13505)
Parent: 947dbb73bb
Commit: a6f999b62e
@@ -29,6 +29,7 @@ Details are covered in the UPDATING.md files for each library, but there are som

 | Library name | Previous constraints | Current constraints | |
 | --- | --- | --- | --- |
+| [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-automl/blob/master/UPGRADING.md) |
 | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
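Editor's note: the jump from google-cloud-automl 0.x/1.x to 2.x crosses the breaking change introduced by Google's client microgenerator: RPC arguments are passed as a single `request` mapping, resource paths are plain strings, and `metadata` must be an iterable rather than `None`. A minimal sketch of the call-style difference, assuming application-default credentials and a hypothetical project and location, is shown below; the rest of this commit applies the same pattern throughout the provider.

```python
from google.cloud.automl_v1beta1 import AutoMlClient

client = AutoMlClient()  # assumes application-default credentials are configured

# Hypothetical project and location, used only for illustration.
parent = "projects/my-project/locations/us-central1"

# Old style (google-cloud-automl < 2.0.0): one keyword argument per field.
#     client.list_datasets(parent=parent, retry=None, timeout=None, metadata=None)

# New style (google-cloud-automl >= 2.1.0): one `request` mapping, metadata as a tuple.
for dataset in client.list_datasets(request={"parent": parent}, metadata=()):
    print(dataset.name)
```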
@@ -47,7 +47,7 @@ GCP_AUTOML_LOCATION = os.environ.get("GCP_AUTOML_LOCATION", "us-central1")
 GCP_AUTOML_DATASET_BUCKET = os.environ.get(
     "GCP_AUTOML_DATASET_BUCKET", "gs://cloud-ml-tables-data/bank-marketing.csv"
 )
-TARGET = os.environ.get("GCP_AUTOML_TARGET", "Class")
+TARGET = os.environ.get("GCP_AUTOML_TARGET", "Deposit")

 # Example values
 MODEL_ID = "TBL123456"
@@ -76,9 +76,9 @@ def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
     Using column name returns spec of the column.
     """
     for column in columns_specs:
-        if column["displayName"] == column_name:
+        if column["display_name"] == column_name:
             return extract_object_id(column)
-    return ""
+    raise Exception(f"Unknown target column: {column_name}")


 # Example DAG to create dataset, train model_id and deploy it.
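Editor's note: the column specs iterated here are plain dicts produced by the 2.x `ColumnSpec.to_dict()` conversion, so the lookup key is the snake_case field name `display_name` rather than the camelCase `displayName` that `MessageToDict` used to emit. A sketch of the full helper after this change, assuming `extract_object_id` is the static helper exposed by `CloudAutoMLHook`:

```python
from typing import Dict, List

from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook

extract_object_id = CloudAutoMLHook.extract_object_id


def get_target_column_spec(columns_specs: List[Dict], column_name: str) -> str:
    """Return the object id of the column spec whose display_name matches column_name."""
    for column in columns_specs:
        if column["display_name"] == column_name:
            return extract_object_id(column)
    raise Exception(f"Unknown target column: {column_name}")
```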
@@ -20,22 +20,23 @@
 from typing import Dict, List, Optional, Sequence, Tuple, Union

 from cached_property import cached_property
+from google.api_core.operation import Operation
 from google.api_core.retry import Retry
-from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
-from google.cloud.automl_v1beta1.types import (
+from google.cloud.automl_v1beta1 import (
+    AutoMlClient,
     BatchPredictInputConfig,
     BatchPredictOutputConfig,
     ColumnSpec,
     Dataset,
     ExamplePayload,
-    FieldMask,
     ImageObjectDetectionModelDeploymentMetadata,
     InputConfig,
     Model,
-    Operation,
+    PredictionServiceClient,
     PredictResponse,
     TableSpec,
 )
+from google.protobuf.field_mask_pb2 import FieldMask

 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook

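Editor's note: in the 2.x package the generated clients and message classes are all exported from `google.cloud.automl_v1beta1` directly, while the two names that drop out of the old `types` import now live elsewhere: the protobuf `FieldMask` comes from `google.protobuf.field_mask_pb2` and the long-running operation future is wrapped by `google.api_core.operation.Operation`. A small sketch of building an update mask for `update_dataset`, using an illustrative field path:

```python
from google.protobuf.field_mask_pb2 import FieldMask

# Only the listed paths are written by update_dataset; "display_name" is an illustrative field.
update_mask = FieldMask(paths=["display_name"])
```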
@@ -123,9 +124,9 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.get_conn()
-        parent = client.location_path(project_id, location)
+        parent = f"projects/{project_id}/locations/{location}"
         return client.create_model(
-            parent=parent, model=model, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'model': model}, retry=retry, timeout=timeout, metadata=metadata or ()
         )

     @GoogleBaseHook.fallback_to_default_project_id
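Editor's note: the same three-part rewrite repeats through every hook method below: the removed `client.*_path()` helpers become f-string resource names, per-field keyword arguments collapse into a single `request` dict, and `metadata` falls back to an empty tuple because the 2.x transport rejects `None`. From the caller's side the hook keeps its signature, so existing DAG code is unaffected. A hedged sketch of using it, with placeholder project and model values and assuming a configured `google_cloud_default` connection:

```python
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook

hook = CloudAutoMLHook()  # uses the google_cloud_default connection by default

# MODEL is a dict accepted by google.cloud.automl_v1beta1.Model; values are placeholders.
MODEL = {
    "display_name": "example_model",
    "dataset_id": "TBL123456",
    "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
}

operation = hook.create_model(model=MODEL, location="us-central1", project_id="my-project")
model = operation.result()  # blocks until AutoML finishes training
print(model.name)
```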
@@ -176,15 +177,17 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.prediction_client
-        name = client.model_path(project=project_id, location=location, model=model_id)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
         result = client.batch_predict(
-            name=name,
-            input_config=input_config,
-            output_config=output_config,
-            params=params,
+            request={
+                'name': name,
+                'input_config': input_config,
+                'output_config': output_config,
+                'params': params,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result

@@ -229,14 +232,12 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance
         """
         client = self.prediction_client
-        name = client.model_path(project=project_id, location=location, model=model_id)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
         result = client.predict(
-            name=name,
-            payload=payload,
-            params=params,
+            request={'name': name, 'payload': payload, 'params': params},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result

@@ -273,13 +274,12 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
         """
         client = self.get_conn()
-        parent = client.location_path(project=project_id, location=location)
+        parent = f"projects/{project_id}/locations/{location}"
         result = client.create_dataset(
-            parent=parent,
-            dataset=dataset,
+            request={'parent': parent, 'dataset': dataset},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result

@@ -319,13 +319,12 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.get_conn()
-        name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
+        name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
         result = client.import_data(
-            name=name,
-            input_config=input_config,
+            request={'name': name, 'input_config': input_config},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result

@@ -385,13 +384,10 @@ class CloudAutoMLHook(GoogleBaseHook):
             table_spec=table_spec_id,
         )
         result = client.list_column_specs(
-            parent=parent,
-            field_mask=field_mask,
-            filter_=filter_,
-            page_size=page_size,
+            request={'parent': parent, 'field_mask': field_mask, 'filter': filter_, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result

@@ -427,8 +423,10 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types.Model` instance.
         """
         client = self.get_conn()
-        name = client.model_path(project=project_id, location=location, model=model_id)
-        result = client.get_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+        result = client.get_model(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result

     @GoogleBaseHook.fallback_to_default_project_id
@@ -463,8 +461,10 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
         """
         client = self.get_conn()
-        name = client.model_path(project=project_id, location=location, model=model_id)
-        result = client.delete_model(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
+        result = client.delete_model(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result

     def update_dataset(
@@ -497,11 +497,10 @@ class CloudAutoMLHook(GoogleBaseHook):
         """
         client = self.get_conn()
         result = client.update_dataset(
-            dataset=dataset,
-            update_mask=update_mask,
+            request={'dataset': dataset, 'update_mask': update_mask},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result

@@ -547,13 +546,15 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
         """
         client = self.get_conn()
-        name = client.model_path(project=project_id, location=location, model=model_id)
+        name = f"projects/{project_id}/locations/{location}/models/{model_id}"
         result = client.deploy_model(
-            name=name,
+            request={
+                'name': name,
+                'image_object_detection_model_deployment_metadata': image_detection_metadata,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
-            image_object_detection_model_deployment_metadata=image_detection_metadata,
+            metadata=metadata or (),
         )
         return result

@@ -601,14 +602,12 @@ class CloudAutoMLHook(GoogleBaseHook):
         of the response through the `options` parameter.
         """
         client = self.get_conn()
-        parent = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
+        parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
         result = client.list_table_specs(
-            parent=parent,
-            filter_=filter_,
-            page_size=page_size,
+            request={'parent': parent, 'filter': filter_, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
         return result

@@ -644,8 +643,10 @@ class CloudAutoMLHook(GoogleBaseHook):
         of the response through the `options` parameter.
         """
         client = self.get_conn()
-        parent = client.location_path(project=project_id, location=location)
-        result = client.list_datasets(parent=parent, retry=retry, timeout=timeout, metadata=metadata)
+        parent = f"projects/{project_id}/locations/{location}"
+        result = client.list_datasets(
+            request={'parent': parent}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result

     @GoogleBaseHook.fallback_to_default_project_id
@@ -680,6 +681,8 @@ class CloudAutoMLHook(GoogleBaseHook):
         :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
         """
         client = self.get_conn()
-        name = client.dataset_path(project=project_id, location=location, dataset=dataset_id)
-        result = client.delete_dataset(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
+        result = client.delete_dataset(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
         return result
@@ -22,7 +22,14 @@ import ast
 from typing import Dict, List, Optional, Sequence, Tuple, Union

 from google.api_core.retry import Retry
-from google.protobuf.json_format import MessageToDict
+from google.cloud.automl_v1beta1 import (
+    BatchPredictResult,
+    ColumnSpec,
+    Dataset,
+    Model,
+    PredictResponse,
+    TableSpec,
+)

 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
@@ -113,7 +120,7 @@ class AutoMLTrainModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        result = Model.to_dict(operation.result())
         model_id = hook.extract_object_id(result)
         self.log.info("Model created: %s", model_id)

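Editor's note: the serialization change in the operators is the other half of the migration. 2.x messages are proto-plus wrappers, so the module-level `MessageToDict` helper no longer accepts them directly; each message class instead exposes a `to_dict` classmethod, which also switches the emitted keys from camelCase to the proto field names. A hedged sketch of the difference, using a placeholder model path:

```python
from google.cloud.automl_v1beta1 import Model

model = Model(name="projects/my-project/locations/us-central1/models/TBL123456")

# google-cloud-automl >= 2.x: proto-plus conversion, snake_case keys.
as_dict = Model.to_dict(model)
assert as_dict["name"].endswith("TBL123456")

# The pre-2.0 pattern was MessageToDict(model), which emits camelCase keys and
# only accepts raw protobuf messages, not proto-plus wrappers.
```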
@@ -212,7 +219,7 @@ class AutoMLPredictOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return PredictResponse.to_dict(result)


 class AutoMLBatchPredictOperator(BaseOperator):
@@ -324,7 +331,7 @@ class AutoMLBatchPredictOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        result = BatchPredictResult.to_dict(operation.result())
         self.log.info("Batch prediction ready.")
         return result
@@ -414,7 +421,7 @@ class AutoMLCreateDatasetOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(result)
+        result = Dataset.to_dict(result)
         dataset_id = hook.extract_object_id(result)
         self.log.info("Creating completed. Dataset id: %s", dataset_id)
@@ -513,9 +520,8 @@ class AutoMLImportDataOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        operation.result()
         self.log.info("Import completed")
-        return result


 class AutoMLTablesListColumnSpecsOperator(BaseOperator):
@@ -627,7 +633,7 @@ class AutoMLTablesListColumnSpecsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = [MessageToDict(spec) for spec in page_iterator]
+        result = [ColumnSpec.to_dict(spec) for spec in page_iterator]
         self.log.info("Columns specs obtained.")

         return result
@@ -718,7 +724,7 @@ class AutoMLTablesUpdateDatasetOperator(BaseOperator):
             metadata=self.metadata,
         )
         self.log.info("Dataset updated.")
-        return MessageToDict(result)
+        return Dataset.to_dict(result)


 class AutoMLGetModelOperator(BaseOperator):
@@ -804,7 +810,7 @@ class AutoMLGetModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        return MessageToDict(result)
+        return Model.to_dict(result)


 class AutoMLDeleteModelOperator(BaseOperator):
@@ -890,8 +896,7 @@ class AutoMLDeleteModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
-        return result
+        operation.result()


 class AutoMLDeployModelOperator(BaseOperator):
@@ -991,9 +996,8 @@ class AutoMLDeployModelOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(operation.result())
+        operation.result()
         self.log.info("Model deployed.")
-        return result


 class AutoMLTablesListTableSpecsOperator(BaseOperator):
@@ -1092,7 +1096,7 @@ class AutoMLTablesListTableSpecsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = [MessageToDict(spec) for spec in page_iterator]
-        self.log.info(result)
+        result = [TableSpec.to_dict(spec) for spec in page_iterator]
+        self.log.info("Table specs obtained.")
         return result
@@ -1173,7 +1177,7 @@ class AutoMLListDatasetOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = [MessageToDict(dataset) for dataset in page_iterator]
+        result = [Dataset.to_dict(dataset) for dataset in page_iterator]
         self.log.info("Datasets obtained.")

         self.xcom_push(
setup.py

@@ -281,7 +281,7 @@ google = [
     'google-api-python-client>=1.6.0,<2.0.0',
     'google-auth>=1.0.0,<2.0.0',
     'google-auth-httplib2>=0.0.1',
-    'google-cloud-automl>=0.4.0,<2.0.0',
+    'google-cloud-automl>=2.1.0,<3.0.0',
     'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
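Editor's note: one quick way to confirm an environment satisfies the tightened extra before upgrading the provider is to check the installed distribution against the same specifier added to `setup.py` above; this is just one option, shown as a sketch.

```python
import pkg_resources

# Raises VersionConflict or DistributionNotFound if the new pin is not met.
pkg_resources.require("google-cloud-automl>=2.1.0,<3.0.0")
```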
@@ -19,7 +19,7 @@
 import unittest
 from unittest import mock

-from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
+from google.cloud.automl_v1beta1 import AutoMlClient

 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id
@@ -38,9 +38,9 @@ MODEL = {
     "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
 }

-LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
-MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID)
-DATASET_PATH = AutoMlClient.dataset_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID)
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"

 INPUT_CONFIG = {"input": "value"}
 OUTPUT_CONFIG = {"output": "value"}
@@ -81,7 +81,7 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)

         mock_create_model.assert_called_once_with(
-            parent=LOCATION_PATH, model=MODEL, retry=None, timeout=None, metadata=None
+            request=dict(parent=LOCATION_PATH, model=MODEL), retry=None, timeout=None, metadata=()
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict")
@@ -95,13 +95,12 @@ class TestAuoMLHook(unittest.TestCase):
         )

         mock_batch_predict.assert_called_once_with(
-            name=MODEL_PATH,
-            input_config=INPUT_CONFIG,
-            output_config=OUTPUT_CONFIG,
-            params=None,
+            request=dict(
+                name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict")
@@ -114,12 +113,10 @@ class TestAuoMLHook(unittest.TestCase):
         )

         mock_predict.assert_called_once_with(
-            name=MODEL_PATH,
-            payload=PAYLOAD,
-            params=None,
+            request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset")
@@ -127,11 +124,10 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)

         mock_create_dataset.assert_called_once_with(
-            parent=LOCATION_PATH,
-            dataset=DATASET,
+            request=dict(parent=LOCATION_PATH, dataset=DATASET),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data")
@@ -144,11 +140,10 @@ class TestAuoMLHook(unittest.TestCase):
         )

         mock_import_data.assert_called_once_with(
-            name=DATASET_PATH,
-            input_config=INPUT_CONFIG,
+            request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs")
@@ -169,26 +164,27 @@ class TestAuoMLHook(unittest.TestCase):

         parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec)
         mock_list_column_specs.assert_called_once_with(
-            parent=parent,
-            field_mask=MASK,
-            filter_=filter_,
-            page_size=page_size,
+            request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model")
     def test_get_model(self, mock_get_model):
         self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)

-        mock_get_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None)
+        mock_get_model.assert_called_once_with(
+            request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=()
+        )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model")
     def test_delete_model(self, mock_delete_model):
         self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)

-        mock_delete_model.assert_called_once_with(name=MODEL_PATH, retry=None, timeout=None, metadata=None)
+        mock_delete_model.assert_called_once_with(
+            request=dict(name=MODEL_PATH), retry=None, timeout=None, metadata=()
+        )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset")
     def test_update_dataset(self, mock_update_dataset):
@@ -198,7 +194,7 @@ class TestAuoMLHook(unittest.TestCase):
         )

         mock_update_dataset.assert_called_once_with(
-            dataset=DATASET, update_mask=MASK, retry=None, timeout=None, metadata=None
+            request=dict(dataset=DATASET, update_mask=MASK), retry=None, timeout=None, metadata=()
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model")
@@ -213,11 +209,13 @@ class TestAuoMLHook(unittest.TestCase):
         )

         mock_deploy_model.assert_called_once_with(
-            name=MODEL_PATH,
+            request=dict(
+                name=MODEL_PATH,
+                image_object_detection_model_deployment_metadata=image_detection_metadata,
+            ),
             retry=None,
             timeout=None,
-            metadata=None,
-            image_object_detection_model_deployment_metadata=image_detection_metadata,
+            metadata=(),
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs")
@@ -234,12 +232,10 @@ class TestAuoMLHook(unittest.TestCase):
         )

         mock_list_table_specs.assert_called_once_with(
-            parent=DATASET_PATH,
-            filter_=filter_,
-            page_size=page_size,
+            request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size),
             retry=None,
             timeout=None,
-            metadata=None,
+            metadata=(),
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets")
@@ -247,7 +243,7 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)

         mock_list_datasets.assert_called_once_with(
-            parent=LOCATION_PATH, retry=None, timeout=None, metadata=None
+            request=dict(parent=LOCATION_PATH), retry=None, timeout=None, metadata=()
         )

     @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset")
@@ -255,5 +251,5 @@ class TestAuoMLHook(unittest.TestCase):
         self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)

         mock_delete_dataset.assert_called_once_with(
-            name=DATASET_PATH, retry=None, timeout=None, metadata=None
+            request=dict(name=DATASET_PATH), retry=None, timeout=None, metadata=()
         )
@@ -20,8 +20,9 @@ import copy
 import unittest
 from unittest import mock

-from google.cloud.automl_v1beta1 import AutoMlClient, PredictionServiceClient
+from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse

 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.operators.automl import (
     AutoMLBatchPredictOperator,
     AutoMLCreateDatasetOperator,
@@ -43,7 +44,7 @@ TASK_ID = "test-automl-hook"
 GCP_PROJECT_ID = "test-project"
 GCP_LOCATION = "test-location"
 MODEL_NAME = "test_model"
-MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
+MODEL_ID = "TBL9195602771183665152"
 DATASET_ID = "TBL123456789"
 MODEL = {
     "display_name": MODEL_NAME,
@@ -51,8 +52,9 @@ MODEL = {
     "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
 }

-LOCATION_PATH = AutoMlClient.location_path(GCP_PROJECT_ID, GCP_LOCATION)
-MODEL_PATH = PredictionServiceClient.model_path(GCP_PROJECT_ID, GCP_LOCATION, MODEL_ID)
+LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
+MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
+DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"

 INPUT_CONFIG = {"input": "value"}
 OUTPUT_CONFIG = {"output": "value"}
@@ -60,12 +62,15 @@ PAYLOAD = {"test": "payload"}
 DATASET = {"dataset_id": "data"}
 MASK = {"field": "mask"}

+extract_object_id = CloudAutoMLHook.extract_object_id
+

 class TestAutoMLTrainModelOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator.xcom_push")
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook, mock_xcom):
-        mock_hook.return_value.extract_object_id.return_value = MODEL_ID
+        mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH)
+        mock_hook.return_value.extract_object_id = extract_object_id
         op = AutoMLTrainModelOperator(
             model=MODEL,
             location=GCP_LOCATION,
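Editor's note: because the operators now call `<Message>.to_dict()` on whatever the mocked hook returns, these tests hand back real proto-plus instances instead of bare mocks, constructed with only the fields a test needs. A small sketch of the pattern the updated tests rely on, reusing the placeholder path constant defined above:

```python
from unittest import mock

from google.cloud.automl_v1beta1 import Model

MODEL_PATH = "projects/test-project/locations/test-location/models/TBL9195602771183665152"

mock_hook = mock.MagicMock()
# The long-running operation's result() must yield something Model.to_dict() accepts.
mock_hook.create_model.return_value.result.return_value = Model(name=MODEL_PATH)

result = Model.to_dict(mock_hook.create_model.return_value.result.return_value)
assert result["name"] == MODEL_PATH
```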
@@ -87,6 +92,9 @@ class TestAutoMLTrainModelOperator(unittest.TestCase):
 class TestAutoMLBatchPredictOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
+        mock_hook.return_value.extract_object_id = extract_object_id
+
         op = AutoMLBatchPredictOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,
@@ -113,6 +121,8 @@ class TestAutoMLBatchPredictOperator(unittest.TestCase):
 class TestAutoMLPredictOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.predict.return_value = PredictResponse()
+
         op = AutoMLPredictOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,
@@ -137,7 +147,9 @@ class TestAutoMLCreateImportOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator.xcom_push")
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook, mock_xcom):
-        mock_hook.return_value.extract_object_id.return_value = DATASET_ID
+        mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
+        mock_hook.return_value.extract_object_id = extract_object_id
+
         op = AutoMLCreateDatasetOperator(
             dataset=DATASET,
             location=GCP_LOCATION,
@@ -191,6 +203,8 @@ class TestAutoMLListColumnsSpecsOperator(unittest.TestCase):
 class TestAutoMLUpdateDatasetOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
+
         dataset = copy.deepcopy(DATASET)
         dataset["name"] = DATASET_ID

@@ -213,6 +227,9 @@ class TestAutoMLUpdateDatasetOperator(unittest.TestCase):
 class TestAutoMLGetModelOperator(unittest.TestCase):
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
     def test_execute(self, mock_hook):
+        mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH)
+        mock_hook.return_value.extract_object_id = extract_object_id
+
         op = AutoMLGetModelOperator(
             model_id=MODEL_ID,
             location=GCP_LOCATION,