зеркало из https://github.com/microsoft/glue.git
Changed config params to dictionary
This commit is contained in:
Родитель
6cec1f54a9
Коммит
2aebfadf61
|
@ -44,7 +44,7 @@ if __name__ == '__main__':
|
|||
|
||||
# Case Management
|
||||
if any([do_scoring, do_synthesize, do_transcribe, do_evaluate]):
|
||||
output_folder, case = he.create_case(pa.output_folder)
|
||||
output_folder, case = he.create_case(pa.config_data['output_folder'])
|
||||
logging.info(f'[INFO] - Created case {case}')
|
||||
try:
|
||||
os.makedirs(f"{output_folder}/{case}/input", exist_ok=True)
|
||||
|
@ -100,7 +100,7 @@ if __name__ == '__main__':
|
|||
# LUIS Scoring
|
||||
if do_scoring:
|
||||
logging.info('[INFO] - Starting with LUIS scoring')
|
||||
logging.info(f'[INFO] - Set LUIS treshold to {pa.luis_treshold}')
|
||||
logging.info(f'[INFO] - Set LUIS treshold to {pa.config_data["luis_treshold"]}')
|
||||
if 'intent' in list(df_reference.columns) and not 'rec' in list(df_reference.columns):
|
||||
luis_scoring = luis.main(df_reference, 'text')
|
||||
elif all(['intent' in list(df_reference.columns), 'rec' in list(df_reference.columns)]):
|
||||
|
@ -111,7 +111,7 @@ if __name__ == '__main__':
|
|||
else:
|
||||
logging.error('[ERROR] - Cannot do LUIS scoring, please verify that you have an "intent"-column in your data.')
|
||||
# Write to output file
|
||||
luis_scoring['luis_treshold'] = pa.luis_treshold
|
||||
luis_scoring['luis_treshold'] = pa.config_data['luis_treshold']
|
||||
luis_scoring.to_csv(f'{output_folder}/{case}/luis_scoring.csv', sep = ',', encoding = 'utf-8', index=False)
|
||||
|
||||
# Finish run
|
||||
|
|
|
@ -26,7 +26,7 @@ def request_luis(text):
|
|||
# Uncomment this if you are using the old url version having the region name as endpoint.
|
||||
# endpoint_url = f'{endpoint}.api.cognitive.microsoft.com'.
|
||||
# Below, you see the most current version of the api having the prediction resource name as endpoint.
|
||||
endpoint_url = f'{pa.luis_endpoint}.cognitiveservices.azure.com'
|
||||
endpoint_url = f'{pa.config_data["luis_endpoint"]}.cognitiveservices.azure.com'
|
||||
headers = {}
|
||||
params = {
|
||||
'query': text,
|
||||
|
@ -35,9 +35,9 @@ def request_luis(text):
|
|||
'show-all-intents': 'true',
|
||||
'spellCheck': 'false',
|
||||
'staging': 'false',
|
||||
'subscription-key': pa.luis_key
|
||||
'subscription-key': pa.config_data['luis_key']
|
||||
}
|
||||
r = requests.get(f'https://{endpoint_url}/luis/prediction/v3.0/apps/{pa.luis_appid}/slots/{pa.luis_slot}/predict', headers=headers, params=params)
|
||||
r = requests.get(f'https://{endpoint_url}/luis/prediction/v3.0/apps/{pa.config_data["luis_appid"]}/slots/{pa.config_data["luis_slot"]}/predict', headers=headers, params=params)
|
||||
# Check
|
||||
logging.debug(json.dumps(json.loads(r.text), indent=2))
|
||||
return r.json()
|
||||
|
@ -53,7 +53,7 @@ def luis_classification_report(df, col):
|
|||
logging.info('[INFO] - Starting to create classification report')
|
||||
logging.info('[OUTPUT] - CLASSIFICATION REPORT (without reset by treshold):')
|
||||
logging.info(classification_report(df['intent'], df[f'prediction_{col}']))
|
||||
logging.info(f'[OUTPUT] - AFTER RESET BY TRESHOLD ({pa.luis_treshold}):')
|
||||
logging.info(f'[OUTPUT] - AFTER RESET BY TRESHOLD ({pa.config_data["luis_treshold"]}):')
|
||||
logging.info(classification_report(df['intent'], df[f'prediction_drop_{col}']))
|
||||
logging.info('[OUTPUT] - CONFUSION MATRIX:')
|
||||
logging.info(f'\n{confusion_matrix(df["intent"], df[f"prediction_{col}"])}')
|
||||
|
@ -79,7 +79,7 @@ def main(df, col):
|
|||
top_intent = data['prediction']['topIntent']
|
||||
top_score = data['prediction']['intents'][top_intent]['score']
|
||||
# Evaluate scores based on treshold and set None-intent if confidence is too low
|
||||
if top_score < pa.luis_treshold:
|
||||
if top_score < pa.config_data['luis_treshold']:
|
||||
drop = "None"
|
||||
else:
|
||||
drop = top_intent
|
||||
|
|
|
@ -50,28 +50,30 @@ def get_config(fname_config='config.ini'):
|
|||
# Get config file
|
||||
sys.path.append('./')
|
||||
config = configparser.ConfigParser()
|
||||
global output_folder, driver, luis_appid, luis_key, luis_region, luis_endpoint, luis_slot, luis_treshold, stt_key, stt_endpoint, stt_region, tts_key, tts_region, tts_resource_name, tts_language, tts_font
|
||||
global config_data
|
||||
#global output_folder, driver, luis_appid, luis_key, luis_region, luis_endpoint, luis_slot, luis_treshold, stt_key, stt_endpoint, stt_region, tts_key, tts_region, tts_resource_name, tts_language, tts_font
|
||||
config.read(fname_config)
|
||||
|
||||
# Read keys/values and assign it to variables
|
||||
try:
|
||||
output_folder = config['dir']['output_folder']
|
||||
stt_key = config['stt']['key']
|
||||
stt_endpoint = config['stt']['endpoint']
|
||||
stt_region = config['stt']['region']
|
||||
tts_key = config['tts']['key']
|
||||
tts_region = config['tts']['region']
|
||||
tts_resource_name = config['tts']['resource_name']
|
||||
tts_language = config['tts']['language']
|
||||
tts_font = config['tts']['font']
|
||||
luis_appid = config['luis']['app_id']
|
||||
luis_key = config['luis']['key']
|
||||
luis_region = config['luis']['region']
|
||||
luis_endpoint = config['luis']['endpoint']
|
||||
luis_slot = config['luis']['slot']
|
||||
luis_treshold = float(config['luis']['treshold'])
|
||||
luis_treshold = 0 if luis_treshold == '' else luis_treshold
|
||||
driver = config['driver']['path']
|
||||
config_data = dict()
|
||||
config_data['output_folder'] = config['dir']['output_folder']
|
||||
config_data['stt_key'] = config['stt']['key']
|
||||
config_data['stt_endpoint'] = config['stt']['endpoint']
|
||||
config_data['stt_region'] = config['stt']['region']
|
||||
config_data['tts_key'] = config['tts']['key']
|
||||
config_data['tts_region'] = config['tts']['region']
|
||||
config_data['tts_resource_name'] = config['tts']['resource_name']
|
||||
config_data['tts_language'] = config['tts']['language']
|
||||
config_data['tts_font'] = config['tts']['font']
|
||||
config_data['luis_appid'] = config['luis']['app_id']
|
||||
config_data['luis_key'] = config['luis']['key']
|
||||
config_data['luis_region'] = config['luis']['region']
|
||||
config_data['luis_endpoint'] = config['luis']['endpoint']
|
||||
config_data['luis_slot'] = config['luis']['slot']
|
||||
config_data['luis_treshold'] = float(config['luis']['treshold'])
|
||||
config_data['luis_treshold'] = 0 if config_data['luis_treshold'] == '' else config_data['luis_treshold']
|
||||
config_data['driver'] = config['driver']['path']
|
||||
except KeyError as e:
|
||||
logging.error(f'[ERROR] - Exit with KeyError for {e}, please verify structure and existance of your config.ini file. You may use config.sample.ini as guidance.')
|
||||
sys.exit()
|
||||
|
|
|
@ -90,7 +90,7 @@ def main(speech_files, output_directory, lexical = False, enable_proxy = False,
|
|||
zip(filenames, results): Zipped lists of filenames and STT-results as string
|
||||
"""
|
||||
try:
|
||||
speech_config = speechsdk.SpeechConfig(subscription = pa.stt_key, region = pa.stt_region)
|
||||
speech_config = speechsdk.SpeechConfig(subscription = pa.config_data['stt_key'], region = pa.config_data['stt_region'])
|
||||
except RuntimeError:
|
||||
logging.error("[ERROR] - Could not retrieve speech config")
|
||||
# If necessary, you can enable a proxy here:
|
||||
|
@ -99,8 +99,8 @@ def main(speech_files, output_directory, lexical = False, enable_proxy = False,
|
|||
speech_config.set_proxy(argv[0], argv[1], argv[2], argv[3])
|
||||
# Set speech service properties, requesting the detailed response format to make it compatible with lexical format, if wanted
|
||||
speech_config.set_service_property(name='format', value='detailed', channel=speechsdk.ServicePropertyChannel.UriQueryParameter)
|
||||
if pa.stt_endpoint != "":
|
||||
speech_config.endpoint_id = pa.stt_endpoint
|
||||
if pa.config_data['stt_endpoint'] != "":
|
||||
speech_config.endpoint_id = pa.config_data['stt_endpoint']
|
||||
logging.info(f'[INFO] - Starting to transcribe {len(next(os.walk(speech_files))[2])} audio files')
|
||||
results = []
|
||||
filenames = []
|
||||
|
|
10
src/tts.py
10
src/tts.py
|
@ -128,7 +128,7 @@ def main(df, output_directory, custom=True, telephone=True):
|
|||
"""
|
||||
# Check if it's Windows for driver import - if not, setting of driver is not necessary
|
||||
if os.name == "nt":
|
||||
AudioSegment.ffmpeg = pa.driver
|
||||
AudioSegment.ffmpeg = pa.config_data['driver']
|
||||
logging.debug("Running on Windows")
|
||||
else:
|
||||
logging.debug("Running on Linux")
|
||||
|
@ -136,19 +136,19 @@ def main(df, output_directory, custom=True, telephone=True):
|
|||
os.makedirs(f'{output_directory}/tts_generated/', exist_ok=True)
|
||||
audio_synth = []
|
||||
# Instantiate SpeechConfig for the entire run, as well as voice name and audio format
|
||||
speech_config = SpeechConfig(subscription=pa.tts_key, region=pa.tts_region)
|
||||
speech_config.speech_synthesis_voice_name = f'{pa.tts_language}-{pa.tts_font}'
|
||||
speech_config = SpeechConfig(subscription=pa.config_data['tts_key'], region=pa.config_data['tts_region'])
|
||||
speech_config.speech_synthesis_voice_name = f'{pa.config_data["tts_language"]}-{pa.config_data["tts_font"]}'
|
||||
speech_config.set_speech_synthesis_output_format(SpeechSynthesisOutputFormat['Riff24Khz16BitMonoPcm'])
|
||||
# Loop through dataframe of utterances
|
||||
for index, row in df.iterrows():
|
||||
# Submit request to TTS
|
||||
try:
|
||||
fname = f"{datetime.today().strftime('%Y-%m-%d')}_{pa.tts_language}_{pa.tts_font}_{str(uuid.uuid4().hex)}.wav"
|
||||
fname = f"{datetime.today().strftime('%Y-%m-%d')}_{pa.config_data['tts_language']}_{pa.config_data['tts_font']}_{str(uuid.uuid4().hex)}.wav"
|
||||
# AudioOutputConfig has to be set separately due to the file names
|
||||
audio_config = AudioOutputConfig(filename=f'{output_directory}/tts_generated/{fname}')
|
||||
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
|
||||
# Submit request and write outputs
|
||||
synthesizer.speak_ssml_async(get_ssml_string(row['text'], pa.tts_language, pa.tts_font))
|
||||
synthesizer.speak_ssml_async(get_ssml_string(row['text'], pa.config_data['tts_language'], pa.config_data['tts_font']))
|
||||
except Exception as e:
|
||||
logging.error(f'[ERROR] - Synthetization of "{row["text"]}" failed -> {e}')
|
||||
audio_synth.append('nan')
|
||||
|
|
Загрузка…
Ссылка в новой задаче