[AIRFLOW-3960] Adds Google Cloud Speech operators (#4780)
This commit is contained in:
Родитель
75c633e70f
Коммит
5c17948184
|
@ -0,0 +1,80 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Example Airflow DAG that runs speech synthesizing and stores output in Google Cloud Storage
|
||||
|
||||
This DAG relies on the following OS environment variables
|
||||
https://airflow.apache.org/concepts.html#variables
|
||||
* GCP_PROJECT_ID - Google Cloud Platform project used for the Cloud Text-to-Speech and Speech-to-Text APIs.
|
||||
* GCP_SPEECH_TEST_BUCKET - Name of the bucket in which the output file should be stored.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from airflow.utils import dates
|
||||
from airflow import models
|
||||
from airflow.contrib.operators.gcp_text_to_speech_operator import GcpTextToSpeechSynthesizeOperator
|
||||
from airflow.contrib.operators.gcp_speech_to_text_operator import GcpSpeechToTextRecognizeSpeechOperator
|
||||
|
||||
# [START howto_operator_text_to_speech_env_variables]
GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
BUCKET_NAME = os.environ.get("GCP_SPEECH_TEST_BUCKET", "gcp-speech-test-bucket")
# [END howto_operator_text_to_speech_env_variables]

# [START howto_operator_text_to_speech_gcp_filename]
FILENAME = "gcp-speech-test-file"
# [END howto_operator_text_to_speech_gcp_filename]

# [START howto_operator_text_to_speech_api_arguments]
INPUT = {"text": "This is just a test"}
VOICE = {
    "language_code": "en-US",
    "ssml_gender": "FEMALE",
}
AUDIO_CONFIG = {"audio_encoding": "LINEAR16"}
# [END howto_operator_text_to_speech_api_arguments]

# [START howto_operator_speech_to_text_api_arguments]
CONFIG = {
    "encoding": "LINEAR16",
    "language_code": "en_US",
}
# The recognition step reads the object that the synthesis step uploads.
AUDIO = {"uri": "gs://{bucket}/{object}".format(bucket=BUCKET_NAME, object=FILENAME)}
# [END howto_operator_speech_to_text_api_arguments]

default_args = {"start_date": dates.days_ago(1)}

with models.DAG(
    "example_gcp_speech",
    default_args=default_args,
    schedule_interval=None,  # Override to match your needs
) as dag:

    # [START howto_operator_text_to_speech_synthesize]
    text_to_speech_synthesize_task = GcpTextToSpeechSynthesizeOperator(
        task_id="text_to_speech_synthesize_task",
        project_id=GCP_PROJECT_ID,
        input_data=INPUT,
        voice=VOICE,
        audio_config=AUDIO_CONFIG,
        target_bucket_name=BUCKET_NAME,
        target_filename=FILENAME,
    )
    # [END howto_operator_text_to_speech_synthesize]

    # [START howto_operator_speech_to_text_recognize]
    speech_to_text_recognize_task = GcpSpeechToTextRecognizeSpeechOperator(
        task_id="speech_to_text_recognize_task",
        project_id=GCP_PROJECT_ID,
        config=CONFIG,
        audio=AUDIO,
    )
    # [END howto_operator_speech_to_text_recognize]

    # Synthesize first, then recognize the audio that was just stored.
    text_to_speech_synthesize_task >> speech_to_text_recognize_task
|
|
@ -0,0 +1,73 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from google.cloud.speech_v1 import SpeechClient
|
||||
|
||||
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
|
||||
|
||||
|
||||
class GCPSpeechToTextHook(GoogleCloudBaseHook):
    """
    Hook for Google Cloud Speech API.

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: str
    """

    # Client cache; stays None until get_conn() creates the client.
    _client = None

    def __init__(self, gcp_conn_id="google_cloud_default", delegate_to=None):
        super(GCPSpeechToTextHook, self).__init__(gcp_conn_id, delegate_to)

    def get_conn(self):
        """
        Retrieves connection to Cloud Speech.

        The client is created on the first call and reused afterwards.

        :return: Google Cloud Speech client object.
        :rtype: google.cloud.speech_v1.SpeechClient
        """
        if not self._client:
            self._client = SpeechClient(credentials=self._get_credentials())
        return self._client

    def recognize_speech(self, config, audio, retry=None, timeout=None):
        """
        Recognizes audio input

        :param config: information to the recognizer that specifies how to process the request.
            https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionConfig
        :type config: dict or google.cloud.speech_v1.types.RecognitionConfig
        :param audio: audio data to be recognized
            https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionAudio
        :type audio: dict or google.cloud.speech_v1.types.RecognitionAudio
        :param retry: (Optional) A retry object used to retry requests. If None is specified,
            requests will not be retried.
        :type retry: google.api_core.retry.Retry
        :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete.
            Note that if retry is specified, the timeout applies to each individual attempt.
        :type timeout: float
        :return: the value returned by the client's ``recognize`` call, unmodified.
        """
        client = self.get_conn()
        response = client.recognize(config=config, audio=audio, retry=retry, timeout=timeout)
        # Hand the response to the logger as an argument so the message is
        # only formatted when the INFO level is actually emitted.
        self.log.info("Recognised speech: %s", response)
        return response
|
|
@ -0,0 +1,80 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from google.cloud.texttospeech_v1 import TextToSpeechClient
|
||||
|
||||
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
|
||||
|
||||
|
||||
class GCPTextToSpeechHook(GoogleCloudBaseHook):
    """
    Hook for Google Cloud Text to Speech API.

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate, if any.
        For this to work, the service account making the request must have
        domain-wide delegation enabled.
    :type delegate_to: str
    """

    # Client cache; stays None until get_conn() creates the client.
    _client = None

    def __init__(self, gcp_conn_id="google_cloud_default", delegate_to=None):
        super(GCPTextToSpeechHook, self).__init__(gcp_conn_id, delegate_to)

    def get_conn(self):
        """
        Retrieves connection to Cloud Text to Speech.

        The client is created on the first call and reused afterwards.

        :return: Google Cloud Text to Speech client object.
        :rtype: google.cloud.texttospeech_v1.TextToSpeechClient
        """
        if not self._client:
            self._client = TextToSpeechClient(credentials=self._get_credentials())
        return self._client

    def synthesize_speech(self, input_data, voice, audio_config, retry=None, timeout=None):
        """
        Synthesizes text input

        :param input_data: text input to be synthesized. See more:
            https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesisInput
        :type input_data: dict or google.cloud.texttospeech_v1.types.SynthesisInput
        :param voice: configuration of voice to be used in synthesis. See more:
            https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.VoiceSelectionParams
        :type voice: dict or google.cloud.texttospeech_v1.types.VoiceSelectionParams
        :param audio_config: configuration of the synthesized audio. See more:
            https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.AudioConfig
        :type audio_config: dict or google.cloud.texttospeech_v1.types.AudioConfig
        :param retry: (Optional) A retry object used to retry requests. If None is specified,
            requests will not be retried.
        :type retry: google.api_core.retry.Retry
        :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete.
            Note that if retry is specified, the timeout applies to each individual attempt.
        :type timeout: float
        :return: SynthesizeSpeechResponse See more:
            https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesizeSpeechResponse
        :rtype: object
        """
        client = self.get_conn()
        # Hand the input to the logger as an argument so the message is
        # only formatted when the INFO level is actually emitted.
        self.log.info("Synthesizing input: %s", input_data)
        # The client names its first argument ``input_`` while this hook
        # exposes it as ``input_data``.
        return client.synthesize_speech(
            input_=input_data, voice=voice, audio_config=audio_config, retry=retry, timeout=timeout
        )
|
|
@ -0,0 +1,90 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.hooks.gcp_speech_to_text_hook import GCPSpeechToTextHook
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class GcpSpeechToTextRecognizeSpeechOperator(BaseOperator):
    """
    Recognizes speech from audio file and returns it as text.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:GcpSpeechToTextRecognizeSpeechOperator`

    :param config: information to the recognizer that specifies how to process the request. See more:
        https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionConfig
    :type config: dict or google.cloud.speech_v1.types.RecognitionConfig
    :param audio: audio data to be recognized. See more:
        https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/types.html#google.cloud.speech_v1.types.RecognitionAudio
    :type audio: dict or google.cloud.speech_v1.types.RecognitionAudio
    :param project_id: Optional, Google Cloud Platform Project ID. It is stored on the
        operator and templated, but ``execute`` does not pass it to the hook.
    :type project_id: str
    :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud
        Platform. Defaults to 'google_cloud_default'.
    :type gcp_conn_id: str
    :param retry: (Optional) A retry object used to retry requests. If None is specified,
        requests will not be retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete.
        Note that if retry is specified, the timeout applies to each individual attempt.
    :type timeout: float
    """

    # [START gcp_speech_to_text_synthesize_template_fields]
    template_fields = ("audio", "config", "project_id", "gcp_conn_id", "timeout")
    # [END gcp_speech_to_text_synthesize_template_fields]

    @apply_defaults
    def __init__(
        self,
        audio,
        config,
        project_id=None,
        gcp_conn_id="google_cloud_default",
        retry=None,
        timeout=None,
        *args,
        **kwargs
    ):
        # Request payload.
        self.audio = audio
        self.config = config
        # Connection settings.
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        # Per-call options forwarded to the hook.
        self.retry = retry
        self.timeout = timeout
        # Validation runs before the base operator is initialised.
        self._validate_inputs()
        super(GcpSpeechToTextRecognizeSpeechOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject empty-string values for the required parameters.

        :raises AirflowException: if ``audio`` or ``config`` equals ``""``.
        """
        if self.audio == "":
            raise AirflowException("The required parameter 'audio' is empty")

        if self.config == "":
            raise AirflowException("The required parameter 'config' is empty")

    def execute(self, context):
        """
        Run recognition through the hook and return its result.

        :param context: task context supplied by Airflow; not read here.
        :return: whatever ``GCPSpeechToTextHook.recognize_speech`` returns.
        """
        speech_hook = GCPSpeechToTextHook(gcp_conn_id=self.gcp_conn_id)
        recognition = speech_hook.recognize_speech(
            config=self.config,
            audio=self.audio,
            retry=self.retry,
            timeout=self.timeout,
        )
        return recognition
|
|
@ -0,0 +1,129 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.hooks.gcp_text_to_speech_hook import GCPTextToSpeechHook
|
||||
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
|
||||
from airflow.models import BaseOperator
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
|
||||
|
||||
class GcpTextToSpeechSynthesizeOperator(BaseOperator):
    """
    Synthesizes text to speech and stores it in Google Cloud Storage

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:GcpTextToSpeechSynthesizeOperator`

    :param input_data: text input to be synthesized. See more:
        https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesisInput
    :type input_data: dict or google.cloud.texttospeech_v1.types.SynthesisInput
    :param voice: configuration of voice to be used in synthesis. See more:
        https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.VoiceSelectionParams
    :type voice: dict or google.cloud.texttospeech_v1.types.VoiceSelectionParams
    :param audio_config: configuration of the synthesized audio. See more:
        https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.AudioConfig
    :type audio_config: dict or google.cloud.texttospeech_v1.types.AudioConfig
    :param target_bucket_name: name of the GCS bucket in which output file should be stored
    :type target_bucket_name: str
    :param target_filename: filename of the output file.
    :type target_filename: str
    :param project_id: Optional, Google Cloud Platform Project ID. It is stored on the
        operator and templated, but ``execute`` does not pass it to the hooks.
    :type project_id: str
    :param gcp_conn_id: Optional, The connection ID used to connect to Google Cloud
        Platform. Defaults to 'google_cloud_default'. The same connection is used for
        both the Text to Speech call and the Cloud Storage upload.
    :type gcp_conn_id: str
    :param retry: (Optional) A retry object used to retry requests. If None is specified,
        requests will not be retried.
    :type retry: google.api_core.retry.Retry
    :param timeout: (Optional) The amount of time, in seconds, to wait for the request to complete.
        Note that if retry is specified, the timeout applies to each individual attempt.
    :type timeout: float
    """

    # [START gcp_text_to_speech_synthesize_template_fields]
    template_fields = (
        "input_data",
        "voice",
        "audio_config",
        "project_id",
        "gcp_conn_id",
        "target_bucket_name",
        "target_filename",
    )
    # [END gcp_text_to_speech_synthesize_template_fields]

    @apply_defaults
    def __init__(
        self,
        input_data,
        voice,
        audio_config,
        target_bucket_name,
        target_filename,
        project_id=None,
        gcp_conn_id="google_cloud_default",
        retry=None,
        timeout=None,
        *args,
        **kwargs
    ):
        self.input_data = input_data
        self.voice = voice
        self.audio_config = audio_config
        self.target_bucket_name = target_bucket_name
        self.target_filename = target_filename
        self.project_id = project_id
        self.gcp_conn_id = gcp_conn_id
        self.retry = retry
        self.timeout = timeout
        # Validation runs before the base operator is initialised.
        self._validate_inputs()
        super(GcpTextToSpeechSynthesizeOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Reject empty-string values for the required parameters.

        :raises AirflowException: if any required parameter equals ``""``.
        """
        for parameter in [
            "input_data",
            "voice",
            "audio_config",
            "target_bucket_name",
            "target_filename",
        ]:
            if getattr(self, parameter) == "":
                raise AirflowException("The required parameter '{}' is empty".format(parameter))

    def execute(self, context):
        """
        Synthesize the input and upload the resulting audio to Cloud Storage.

        :param context: task context supplied by Airflow; not read here.
        """
        gcp_text_to_speech_hook = GCPTextToSpeechHook(gcp_conn_id=self.gcp_conn_id)
        result = gcp_text_to_speech_hook.synthesize_speech(
            input_data=self.input_data,
            voice=self.voice,
            audio_config=self.audio_config,
            retry=self.retry,
            timeout=self.timeout,
        )
        with NamedTemporaryFile() as temp_file:
            temp_file.write(result.audio_content)
            # The upload is given the file *name*, not this handle, so any
            # bytes still held in the write buffer must be flushed first;
            # otherwise the uploaded object can be empty or truncated.
            temp_file.flush()
            cloud_storage_hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.gcp_conn_id)
            cloud_storage_hook.upload(
                bucket=self.target_bucket_name, object=self.target_filename, filename=temp_file.name
            )
|
|
@ -0,0 +1,125 @@
|
|||
.. Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
.. http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
.. Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
|
||||
Google Cloud Text to Speech Operators
|
||||
=====================================
|
||||
|
||||
.. _howto/operator:GcpTextToSpeechSynthesizeOperator:
|
||||
|
||||
GcpTextToSpeechSynthesizeOperator
|
||||
---------------------------------
|
||||
|
||||
Synthesizes text to audio file and stores it to Google Cloud Storage
|
||||
|
||||
For parameter definition, take a look at
|
||||
:class:`airflow.contrib.operators.gcp_text_to_speech_operator.GcpTextToSpeechSynthesizeOperator`
|
||||
|
||||
Arguments
|
||||
"""""""""
|
||||
|
||||
Some arguments in the example DAG are taken from the OS environment variables:
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/example_dags/example_gcp_speech.py
|
||||
:language: python
|
||||
:start-after: [START howto_operator_text_to_speech_env_variables]
|
||||
:end-before: [END howto_operator_text_to_speech_env_variables]
|
||||
|
||||
The ``input_data``, ``voice`` and ``audio_config`` arguments need to be dicts or objects of the corresponding classes from
|
||||
google.cloud.texttospeech_v1.types module
|
||||
|
||||
for more information, see: https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/api.html#google.cloud.texttospeech_v1.TextToSpeechClient.synthesize_speech
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/example_dags/example_gcp_speech.py
|
||||
:language: python
|
||||
:start-after: [START howto_operator_text_to_speech_api_arguments]
|
||||
:end-before: [END howto_operator_text_to_speech_api_arguments]
|
||||
|
||||
filename is a simple string argument:
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/example_dags/example_gcp_speech.py
|
||||
:language: python
|
||||
:start-after: [START howto_operator_text_to_speech_gcp_filename]
|
||||
:end-before: [END howto_operator_text_to_speech_gcp_filename]
|
||||
|
||||
Using the operator
|
||||
""""""""""""""""""
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/example_dags/example_gcp_speech.py
|
||||
:language: python
|
||||
:dedent: 4
|
||||
:start-after: [START howto_operator_text_to_speech_synthesize]
|
||||
:end-before: [END howto_operator_text_to_speech_synthesize]
|
||||
|
||||
Templating
|
||||
""""""""""
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/operators/gcp_text_to_speech_operator.py
|
||||
:language: python
|
||||
:dedent: 4
|
||||
:start-after: [START gcp_text_to_speech_synthesize_template_fields]
|
||||
:end-before: [END gcp_text_to_speech_synthesize_template_fields]
|
||||
|
||||
Google Cloud Speech to Text Operators
|
||||
=====================================
|
||||
|
||||
.. _howto/operator:GcpSpeechToTextRecognizeSpeechOperator:
|
||||
|
||||
GcpSpeechToTextRecognizeSpeechOperator
|
||||
--------------------------------------
|
||||
|
||||
Recognizes speech in audio input and returns text.
|
||||
|
||||
For parameter definition, take a look at
|
||||
:class:`airflow.contrib.operators.gcp_speech_to_text_operator.GcpSpeechToTextRecognizeSpeechOperator`
|
||||
|
||||
Arguments
|
||||
"""""""""
|
||||
|
||||
config and audio arguments need to be dicts or objects of corresponding classes from
|
||||
google.cloud.speech_v1.types module
|
||||
|
||||
for more information, see: https://googleapis.github.io/google-cloud-python/latest/speech/gapic/v1/api.html#google.cloud.speech_v1.SpeechClient.recognize
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/example_dags/example_gcp_speech.py
|
||||
:language: python
|
||||
:start-after: [START howto_operator_speech_to_text_api_arguments]
|
||||
:end-before: [END howto_operator_speech_to_text_api_arguments]
|
||||
|
||||
filename is a simple string argument:
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/example_dags/example_gcp_speech.py
|
||||
:language: python
|
||||
:start-after: [START howto_operator_text_to_speech_gcp_filename]
|
||||
:end-before: [END howto_operator_text_to_speech_gcp_filename]
|
||||
|
||||
Using the operator
|
||||
""""""""""""""""""
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/example_dags/example_gcp_speech.py
|
||||
:language: python
|
||||
:dedent: 4
|
||||
:start-after: [START howto_operator_speech_to_text_recognize]
|
||||
:end-before: [END howto_operator_speech_to_text_recognize]
|
||||
|
||||
Templating
|
||||
""""""""""
|
||||
|
||||
.. literalinclude:: ../../../../airflow/contrib/operators/gcp_speech_to_text_operator.py
|
||||
:language: python
|
||||
:dedent: 4
|
||||
:start-after: [START gcp_speech_to_text_synthesize_template_fields]
|
||||
:end-before: [END gcp_speech_to_text_synthesize_template_fields]
|
|
@ -649,6 +649,23 @@ Cloud Vision Product Search Operators
|
|||
|
||||
They also use :class:`airflow.contrib.hooks.gcp_vision_hook.CloudVisionHook` to communicate with Google Cloud Platform.
|
||||
|
||||
Cloud Text to Speech
|
||||
''''''''''''''''''''
|
||||
|
||||
:class:`airflow.contrib.operators.gcp_text_to_speech_operator.GcpTextToSpeechSynthesizeOperator`
|
||||
Synthesizes input text into audio file and stores this file to GCS.
|
||||
|
||||
They also use :class:`airflow.contrib.hooks.gcp_text_to_speech_hook.GCPTextToSpeechHook` to communicate with Google Cloud Platform.
|
||||
|
||||
Cloud Speech to Text
|
||||
''''''''''''''''''''
|
||||
|
||||
:class:`airflow.contrib.operators.gcp_speech_to_text_operator.GcpSpeechToTextRecognizeSpeechOperator`
|
||||
Recognizes speech in audio input and returns text.
|
||||
|
||||
They also use :class:`airflow.contrib.hooks.gcp_speech_to_text_hook.GCPSpeechToTextHook` to communicate with Google Cloud Platform.
|
||||
|
||||
|
||||
Cloud Translate
|
||||
'''''''''''''''
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -183,6 +183,8 @@ gcp = [
|
|||
'google-cloud-spanner>=1.7.1',
|
||||
'google-cloud-translate>=1.3.3',
|
||||
'google-cloud-vision>=0.35.2',
|
||||
'google-cloud-texttospeech>=0.4.0',
|
||||
'google-cloud-speech>=0.36.3',
|
||||
'grpcio-gcp>=0.2.2',
|
||||
'PyOpenSSL',
|
||||
'pandas-gbq'
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from airflow.contrib.hooks.gcp_speech_to_text_hook import GCPSpeechToTextHook
|
||||
from tests.contrib.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
PROJECT_ID = "project-id"
# "encoding" is the recognition-config key (it was misspelled "ecryption").
# The client is mocked in these tests, so the dict is only passed through.
CONFIG = {"encoding": "LINEAR16"}
AUDIO = {"uri": "gs://bucket/object"}
|
||||
|
||||
|
||||
class TestTextToSpeechOperator(unittest.TestCase):
    """Unit tests for ``GCPSpeechToTextHook`` (despite the class name)."""

    def setUp(self):
        # Swap out the base hook initialiser while the hook is constructed.
        base_init_patch = mock.patch(
            "airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__",
            new=mock_base_gcp_hook_default_project_id,
        )
        with base_init_patch:
            self.gcp_speech_to_text_hook = GCPSpeechToTextHook(gcp_conn_id="test")

    @mock.patch("airflow.contrib.hooks.gcp_speech_to_text_hook.GCPSpeechToTextHook.get_conn")
    def test_synthesize_speech(self, get_conn):
        """``recognize_speech`` forwards its arguments to the client's ``recognize``."""
        fake_client = get_conn.return_value
        fake_client.recognize.return_value = None

        self.gcp_speech_to_text_hook.recognize_speech(config=CONFIG, audio=AUDIO)

        fake_client.recognize.assert_called_once_with(
            config=CONFIG, audio=AUDIO, retry=None, timeout=None
        )
|
|
@ -0,0 +1,56 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from airflow.contrib.hooks.gcp_text_to_speech_hook import GCPTextToSpeechHook
|
||||
from tests.contrib.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError: # pragma: no cover
|
||||
try:
|
||||
import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
# Request payloads handed to the hook; the client is mocked, so they are
# only passed through and compared.
INPUT = {"text": "test text"}
VOICE = {
    "language_code": "en-US",
    "ssml_gender": "FEMALE",
}
AUDIO_CONFIG = {"audio_encoding": "MP3"}


class TestTextToSpeechHook(unittest.TestCase):
    """Unit tests for ``GCPTextToSpeechHook``."""

    def setUp(self):
        # Swap out the base hook initialiser while the hook is constructed.
        base_init_patch = mock.patch(
            "airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook.__init__",
            new=mock_base_gcp_hook_default_project_id,
        )
        with base_init_patch:
            self.gcp_text_to_speech_hook = GCPTextToSpeechHook(gcp_conn_id="test")

    @mock.patch("airflow.contrib.hooks.gcp_text_to_speech_hook.GCPTextToSpeechHook.get_conn")
    def test_synthesize_speech(self, get_conn):
        """``synthesize_speech`` forwards ``input_data`` to the client as ``input_``."""
        fake_client = get_conn.return_value
        fake_client.synthesize_speech.return_value = None

        self.gcp_text_to_speech_hook.synthesize_speech(
            input_data=INPUT,
            voice=VOICE,
            audio_config=AUDIO_CONFIG,
        )

        fake_client.synthesize_speech.assert_called_once_with(
            input_=INPUT,
            voice=VOICE,
            audio_config=AUDIO_CONFIG,
            retry=None,
            timeout=None,
        )
|
Двоичный файл не отображается.
|
@ -0,0 +1,49 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from tests.contrib.utils.base_gcp_system_test_case import SKIP_TEST_WARNING, DagGcpSystemTestCase
|
||||
from tests.contrib.utils.gcp_authenticator import GCP_GCS_KEY
|
||||
|
||||
from tests.contrib.operators.test_gcp_speech_operator_system_helper import GCPTextToSpeechTestHelper
|
||||
|
||||
|
||||
@unittest.skipIf(DagGcpSystemTestCase.skip_check(GCP_GCS_KEY), SKIP_TEST_WARNING)
class GCPTextToSpeechExampleDagSystemTest(DagGcpSystemTestCase):
    """System test that runs the ``example_gcp_speech`` example DAG.

    The class is skipped when ``DagGcpSystemTestCase.skip_check(GCP_GCS_KEY)``
    is truthy. The target bucket is created in ``setUp`` and deleted in
    ``tearDown`` through ``GCPTextToSpeechTestHelper``.
    """

    def __init__(self, method_name="runTest"):
        super(GCPTextToSpeechExampleDagSystemTest, self).__init__(
            method_name, dag_id="example_gcp_speech", gcp_key=GCP_GCS_KEY
        )
        self.helper = GCPTextToSpeechTestHelper()

    def setUp(self):
        super(GCPTextToSpeechExampleDagSystemTest, self).setUp()
        self.gcp_authenticator.gcp_authenticate()
        try:
            self.helper.create_target_bucket()
        finally:
            # unittest does not call tearDown() when setUp() raises, so the
            # authentication has to be revoked here even if creation fails.
            self.gcp_authenticator.gcp_revoke_authentication()

    def tearDown(self):
        try:
            self.gcp_authenticator.gcp_authenticate()
            try:
                self.helper.delete_target_bucket()
            finally:
                # Revoke even when the bucket could not be deleted.
                self.gcp_authenticator.gcp_revoke_authentication()
        finally:
            # Always let the base class clean up, whatever happened above.
            super(GCPTextToSpeechExampleDagSystemTest, self).tearDown()

    def test_run_example_dag_gcp_text_to_speech(self):
        # Runs the DAG selected by the dag_id passed to the constructor.
        self._run_dag()
|
|
@ -0,0 +1,75 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from tests.contrib.utils.base_gcp_system_test_case import RetrieveVariables
|
||||
from tests.contrib.utils.gcp_authenticator import GcpAuthenticator, GCP_GCS_KEY
|
||||
from tests.contrib.utils.logging_command_executor import LoggingCommandExecutor
|
||||
|
||||
|
||||
# Instantiated at import time; the instance is not referenced again in this module.
# NOTE(review): presumably kept for the constructor's side effects - confirm.
retrieve_variables = RetrieveVariables()
|
||||
|
||||
# %-style template with a single %s placeholder; not referenced elsewhere in this module.
SERVICE_EMAIL_FORMAT = "project-%s@storage-transfer-service.iam.gserviceaccount.com"
|
||||
|
||||
# Both values come from the environment and fall back to the defaults given here.
GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
|
||||
TARGET_BUCKET_NAME = os.environ.get("GCP_SPEECH_TEST_BUCKET", "gcp-speech-test-bucket")
|
||||
|
||||
|
||||
class GCPTextToSpeechTestHelper(LoggingCommandExecutor):
    """Creates and deletes the target bucket for the speech system tests via ``gsutil``."""

    def create_target_bucket(self):
        # "gsutil mb -p <project> gs://<bucket>/"
        bucket_url = "gs://%s/" % TARGET_BUCKET_NAME
        command = ["gsutil", "mb", "-p", GCP_PROJECT_ID, bucket_url]
        self.execute_cmd(command)

    def delete_target_bucket(self):
        # "gsutil rm -r gs://<bucket>/"
        bucket_url = "gs://%s/" % TARGET_BUCKET_NAME
        command = ["gsutil", "rm", "-r", bucket_url]
        # NOTE(review): the second positional argument is forwarded as-is;
        # presumably a "silent" flag of execute_cmd - confirm against the base class.
        self.execute_cmd(command, True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: create or delete the target bucket, with the
    # current authentication stored beforehand and restored afterwards.
    parser = argparse.ArgumentParser(description="Create and delete bucket for system tests.")
    parser.add_argument(
        "--action",
        dest="action",
        required=True,
        choices=("create-target-bucket", "delete-target-bucket", "before-tests", "after-tests"),
    )
    action = parser.parse_args().action

    helper = GCPTextToSpeechTestHelper()
    gcp_authenticator = GcpAuthenticator(GCP_GCS_KEY)
    helper.log.info("Starting action: {}".format(action))

    # "before-tests"/"after-tests" are aliases of the explicit bucket actions.
    handlers = {
        "before-tests": helper.create_target_bucket,
        "after-tests": helper.delete_target_bucket,
        "create-target-bucket": helper.create_target_bucket,
        "delete-target-bucket": helper.delete_target_bucket,
    }

    gcp_authenticator.gcp_store_authentication()
    try:
        gcp_authenticator.gcp_authenticate()
        handler = handlers.get(action)
        if handler is None:
            raise Exception("Unknown action: {}".format(action))
        handler()
    finally:
        # Restore the stored authentication whether or not the action succeeded.
        gcp_authenticator.gcp_restore_authentication()

    helper.log.info("Finishing action: {}".format(action))
|
|
@ -0,0 +1,78 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.operators.gcp_speech_to_text_operator import GcpSpeechToTextRecognizeSpeechOperator
|
||||
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
try:
|
||||
import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
# Operator arguments shared by the tests in this module.
PROJECT_ID = "project-id"
|
||||
GCP_CONN_ID = "gcp-conn-id"
|
||||
# Values passed to the operator as ``config`` and ``audio``.
CONFIG = {"encoding": "LINEAR16"}
|
||||
AUDIO = {"uri": "gs://bucket/object"}
|
||||
|
||||
|
||||
class CloudSqlTest(unittest.TestCase):
    """Unit tests for ``GcpSpeechToTextRecognizeSpeechOperator`` with its hook mocked."""

    # NOTE(review): the class name mentions Cloud SQL, but every test here
    # targets the speech-to-text operator.

    @mock.patch("airflow.contrib.operators.gcp_speech_to_text_operator.GCPSpeechToTextHook")
    def test_recognize_speech_green_path(self, mock_hook):
        hook_instance = mock_hook.return_value
        hook_instance.recognize_speech.return_value = True

        operator = GcpSpeechToTextRecognizeSpeechOperator(
            project_id=PROJECT_ID, gcp_conn_id=GCP_CONN_ID, config=CONFIG, audio=AUDIO, task_id="id"
        )
        operator.execute(context={"task_instance": mock.Mock()})

        # One hook built from the connection id, one recognize_speech call
        # with config/audio passed through and retry/timeout left as None.
        mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID)
        hook_instance.recognize_speech.assert_called_once_with(
            config=CONFIG, audio=AUDIO, retry=None, timeout=None
        )

    @mock.patch("airflow.contrib.operators.gcp_speech_to_text_operator.GCPSpeechToTextHook")
    def test_missing_config(self, mock_hook):
        mock_hook.return_value.recognize_speech.return_value = True

        # Construction and execution both stay inside the assertRaises block.
        with self.assertRaises(AirflowException) as raised:
            operator = GcpSpeechToTextRecognizeSpeechOperator(
                project_id=PROJECT_ID, gcp_conn_id=GCP_CONN_ID, audio=AUDIO, task_id="id"
            )
            operator.execute(context={"task_instance": mock.Mock()})

        self.assertIn("config", str(raised.exception))
        mock_hook.assert_not_called()

    @mock.patch("airflow.contrib.operators.gcp_speech_to_text_operator.GCPSpeechToTextHook")
    def test_missing_audio(self, mock_hook):
        mock_hook.return_value.recognize_speech.return_value = True

        # Construction and execution both stay inside the assertRaises block.
        with self.assertRaises(AirflowException) as raised:
            operator = GcpSpeechToTextRecognizeSpeechOperator(
                project_id=PROJECT_ID, gcp_conn_id=GCP_CONN_ID, config=CONFIG, task_id="id"
            )
            operator.execute(context={"task_instance": mock.Mock()})

        self.assertIn("audio", str(raised.exception))
        mock_hook.assert_not_called()
|
|
@ -0,0 +1,112 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from airflow import AirflowException
|
||||
|
||||
from airflow.contrib.operators.gcp_text_to_speech_operator import GcpTextToSpeechSynthesizeOperator
|
||||
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
try:
|
||||
import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
# Operator arguments shared by the tests in this module.
PROJECT_ID = "project-id"
|
||||
GCP_CONN_ID = "gcp-conn-id"
|
||||
# Values passed to the operator as ``input_data``, ``voice`` and ``audio_config``.
INPUT = {"text": "text"}
|
||||
VOICE = {"language_code": "en-US"}
|
||||
AUDIO_CONFIG = {"audio_encoding": "MP3"}
|
||||
# Upload destination passed as ``target_bucket_name`` / ``target_filename``.
TARGET_BUCKET_NAME = "target_bucket_name"
|
||||
TARGET_FILENAME = "target_filename"
|
||||
|
||||
|
||||
class GcpTextToSpeechTest(unittest.TestCase):
    """Unit tests for ``GcpTextToSpeechSynthesizeOperator`` with both hooks mocked."""

    @mock.patch("airflow.contrib.operators.gcp_text_to_speech_operator.GoogleCloudStorageHook")
    @mock.patch("airflow.contrib.operators.gcp_text_to_speech_operator.GCPTextToSpeechHook")
    def test_synthesize_text_green_path(self, mock_text_to_speech_hook, mock_gcp_hook):
        # Patches are injected bottom-up: the text-to-speech hook class comes
        # first, the Cloud Storage hook class (``mock_gcp_hook``) second.
        speech_hook = mock_text_to_speech_hook.return_value
        storage_hook = mock_gcp_hook.return_value

        synthesis_response = mock.Mock()
        type(synthesis_response).audio_content = mock.PropertyMock(return_value=b"audio")
        speech_hook.synthesize_speech.return_value = synthesis_response
        storage_hook.upload.return_value = True

        operator = GcpTextToSpeechSynthesizeOperator(
            project_id=PROJECT_ID,
            gcp_conn_id=GCP_CONN_ID,
            input_data=INPUT,
            voice=VOICE,
            audio_config=AUDIO_CONFIG,
            target_bucket_name=TARGET_BUCKET_NAME,
            target_filename=TARGET_FILENAME,
            task_id="id",
        )
        operator.execute(context={"task_instance": mock.Mock()})

        # Both hooks are built once from the same connection id.
        mock_text_to_speech_hook.assert_called_once_with(gcp_conn_id="gcp-conn-id")
        mock_gcp_hook.assert_called_once_with(google_cloud_storage_conn_id="gcp-conn-id")
        speech_hook.synthesize_speech.assert_called_once_with(
            input_data=INPUT, voice=VOICE, audio_config=AUDIO_CONFIG, retry=None, timeout=None
        )
        # The local file name is not pinned by the test, hence mock.ANY.
        storage_hook.upload.assert_called_once_with(
            bucket=TARGET_BUCKET_NAME, object=TARGET_FILENAME, filename=mock.ANY
        )

    # Each case blanks exactly one argument; the first element names the
    # argument that must appear in the error message.
    @parameterized.expand(
        [
            ("input_data", "", VOICE, AUDIO_CONFIG, TARGET_BUCKET_NAME, TARGET_FILENAME),
            ("voice", INPUT, "", AUDIO_CONFIG, TARGET_BUCKET_NAME, TARGET_FILENAME),
            ("audio_config", INPUT, VOICE, "", TARGET_BUCKET_NAME, TARGET_FILENAME),
            ("target_bucket_name", INPUT, VOICE, AUDIO_CONFIG, "", TARGET_FILENAME),
            ("target_filename", INPUT, VOICE, AUDIO_CONFIG, TARGET_BUCKET_NAME, ""),
        ]
    )
    @mock.patch("airflow.contrib.operators.gcp_text_to_speech_operator.GoogleCloudStorageHook")
    @mock.patch("airflow.contrib.operators.gcp_text_to_speech_operator.GCPTextToSpeechHook")
    def test_missing_arguments(
        self,
        missing_arg,
        input_data,
        voice,
        audio_config,
        target_bucket_name,
        target_filename,
        mock_text_to_speech_hook,
        mock_gcp_hook,
    ):
        # Construction and execution both stay inside the assertRaises block.
        with self.assertRaises(AirflowException) as raised:
            operator = GcpTextToSpeechSynthesizeOperator(
                project_id="project-id",
                input_data=input_data,
                voice=voice,
                audio_config=audio_config,
                target_bucket_name=target_bucket_name,
                target_filename=target_filename,
                task_id="id",
            )
            operator.execute(context={"task_instance": mock.Mock()})

        self.assertIn(missing_arg, str(raised.exception))
        # Neither hook class may be instantiated when an argument is missing.
        mock_text_to_speech_hook.assert_not_called()
        mock_gcp_hook.assert_not_called()
|
Загрузка…
Ссылка в новой задаче