diff --git a/airflow/gcp/hooks/kubernetes_engine.py b/airflow/gcp/hooks/kubernetes_engine.py index 8dd605004c..8c60c4a2a6 100644 --- a/airflow/gcp/hooks/kubernetes_engine.py +++ b/airflow/gcp/hooks/kubernetes_engine.py @@ -21,7 +21,6 @@ This module contains a Google Kubernetes Engine Hook. """ -import json import time from typing import Dict, Union, Optional @@ -32,7 +31,7 @@ from google.api_core.retry import Retry from google.cloud import container_v1, exceptions from google.cloud.container_v1.gapic.enums import Operation from google.cloud.container_v1.types import Cluster -from google.protobuf import json_format +from google.protobuf.json_format import ParseDict from airflow import AirflowException, version from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook @@ -73,22 +72,6 @@ class GKEClusterHook(GoogleCloudBaseHook): ) return self._client - @staticmethod - def _dict_to_proto(py_dict: Dict, proto): - """ - Converts a python dictionary to the proto supplied - - :param py_dict: The dictionary to convert - :type py_dict: dict - :param proto: The proto object to merge with dictionary - :type proto: protobuf - :return: A parsed python dictionary in provided proto format - :raises: - ParseError: On JSON parsing problems. - """ - dict_json_str = json.dumps(py_dict) - return json_format.Parse(dict_json_str, proto) - def wait_for_operation(self, operation: Operation, project_id: str = None) -> Operation: """ Given an operation, continuously fetches the status from Google Cloud until either @@ -227,7 +210,7 @@ class GKEClusterHook(GoogleCloudBaseHook): if isinstance(cluster, dict): cluster_proto = Cluster() - cluster = self._dict_to_proto(py_dict=cluster, proto=cluster_proto) + cluster = ParseDict(cluster, cluster_proto) elif not isinstance(cluster, Cluster): raise AirflowException( "cluster is not instance of Cluster proto or python dict") diff --git a/tests/gcp/hooks/test_kubernetes_engine.py b/tests/gcp/hooks/test_kubernetes_engine.py index 8f8edaaba4..0fbe98961d 100644 --- a/tests/gcp/hooks/test_kubernetes_engine.py +++ b/tests/gcp/hooks/test_kubernetes_engine.py @@ -19,6 +19,8 @@ # import unittest +from google.cloud.container_v1.types import Cluster + from airflow import AirflowException from airflow.gcp.hooks.kubernetes_engine import GKEClusterHook from tests.compat import mock, PropertyMock @@ -39,7 +41,7 @@ class TestGKEClusterHookDelete(unittest.TestCase): new_callable=PropertyMock, return_value=None ) - @mock.patch("airflow.gcp.hooks.kubernetes_engine.GKEClusterHook._dict_to_proto") + @mock.patch("airflow.gcp.hooks.kubernetes_engine.ParseDict") @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.wait_for_operation") def test_delete_cluster(self, wait_mock, convert_mock, mock_project_id): @@ -66,7 +68,7 @@ class TestGKEClusterHookDelete(unittest.TestCase): ) @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.log") - @mock.patch("airflow.gcp.hooks.kubernetes_engine.GKEClusterHook._dict_to_proto") + @mock.patch("airflow.gcp.hooks.kubernetes_engine.ParseDict") @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.wait_for_operation") def test_delete_cluster_not_found(self, wait_mock, convert_mock, log_mock, mock_project_id): @@ -85,7 +87,7 @@ class TestGKEClusterHookDelete(unittest.TestCase): new_callable=PropertyMock, return_value=None ) - @mock.patch("airflow.gcp.hooks.kubernetes_engine.GKEClusterHook._dict_to_proto") + @mock.patch("airflow.gcp.hooks.kubernetes_engine.ParseDict") @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.wait_for_operation") def test_delete_cluster_error(self, wait_mock, convert_mock, mock_project_id): @@ -108,12 +110,10 @@ class TestGKEClusterHookCreate(unittest.TestCase): new_callable=PropertyMock, return_value=None ) - @mock.patch("airflow.gcp.hooks.kubernetes_engine.GKEClusterHook._dict_to_proto") + @mock.patch("airflow.gcp.hooks.kubernetes_engine.ParseDict") @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.wait_for_operation") def test_create_cluster_proto(self, wait_mock, convert_mock, mock_project_id): - from google.cloud.container_v1.proto.cluster_service_pb2 import Cluster - mock_cluster_proto = Cluster() mock_cluster_proto.name = CLUSTER_NAME @@ -138,7 +138,7 @@ class TestGKEClusterHookCreate(unittest.TestCase): new_callable=PropertyMock, return_value=None ) - @mock.patch("airflow.gcp.hooks.kubernetes_engine.GKEClusterHook._dict_to_proto") + @mock.patch("airflow.gcp.hooks.kubernetes_engine.ParseDict") @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.wait_for_operation") def test_create_cluster_dict(self, wait_mock, convert_mock, mock_project_id): @@ -158,9 +158,12 @@ class TestGKEClusterHookCreate(unittest.TestCase): cluster=proto_mock, retry=retry_mock, timeout=timeout_mock) wait_mock.assert_called_once_with(client_create.return_value) - self.assertEqual(convert_mock.call_args[1]['py_dict'], mock_cluster_dict) + convert_mock.assert_called_once_with( + {'name': 'test-cluster'}, + Cluster() + ) - @mock.patch("airflow.gcp.hooks.kubernetes_engine.GKEClusterHook._dict_to_proto") + @mock.patch("airflow.gcp.hooks.kubernetes_engine.ParseDict") @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.wait_for_operation") def test_create_cluster_error(self, wait_mock, convert_mock): @@ -179,7 +182,7 @@ class TestGKEClusterHookCreate(unittest.TestCase): ) @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.log") - @mock.patch("airflow.gcp.hooks.kubernetes_engine.GKEClusterHook._dict_to_proto") + @mock.patch("airflow.gcp.hooks.kubernetes_engine.ParseDict") @mock.patch( "airflow.gcp.hooks.kubernetes_engine.GKEClusterHook.wait_for_operation") def test_create_cluster_already_exists(self, wait_mock, convert_mock, log_mock, mock_project_id): @@ -291,16 +294,3 @@ class TestGKEClusterHook(unittest.TestCase): operation_mock.assert_any_call(running_op.name, project_id=TEST_GCP_PROJECT_ID) operation_mock.assert_any_call(pending_op.name, project_id=TEST_GCP_PROJECT_ID) self.assertEqual(operation_mock.call_count, 2) - - @mock.patch("google.protobuf.json_format.Parse") - @mock.patch("json.dumps") - def test_dict_to_proto(self, dumps_mock, parse_mock): - mock_dict = {'name': 'test'} - mock_proto = mock.Mock() - - dumps_mock.return_value = mock.Mock() - - self.gke_hook._dict_to_proto(mock_dict, mock_proto) - - dumps_mock.assert_called_once_with(mock_dict) - parse_mock.assert_called_once_with(dumps_mock(), mock_proto)