Changed config params to dictionary

This commit is contained in:
nonstoptimm 2022-01-19 16:04:24 +00:00
Родитель 6cec1f54a9
Коммит 2aebfadf61
5 изменённых файлов: 36 добавлений и 34 удалений

Просмотреть файл

@ -44,7 +44,7 @@ if __name__ == '__main__':
# Case Management
if any([do_scoring, do_synthesize, do_transcribe, do_evaluate]):
output_folder, case = he.create_case(pa.output_folder)
output_folder, case = he.create_case(pa.config_data['output_folder'])
logging.info(f'[INFO] - Created case {case}')
try:
os.makedirs(f"{output_folder}/{case}/input", exist_ok=True)
@ -100,7 +100,7 @@ if __name__ == '__main__':
# LUIS Scoring
if do_scoring:
logging.info('[INFO] - Starting with LUIS scoring')
logging.info(f'[INFO] - Set LUIS treshold to {pa.luis_treshold}')
logging.info(f'[INFO] - Set LUIS treshold to {pa.config_data["luis_treshold"]}')
if 'intent' in list(df_reference.columns) and not 'rec' in list(df_reference.columns):
luis_scoring = luis.main(df_reference, 'text')
elif all(['intent' in list(df_reference.columns), 'rec' in list(df_reference.columns)]):
@ -111,7 +111,7 @@ if __name__ == '__main__':
else:
logging.error('[ERROR] - Cannot do LUIS scoring, please verify that you have an "intent"-column in your data.')
# Write to output file
luis_scoring['luis_treshold'] = pa.luis_treshold
luis_scoring['luis_treshold'] = pa.config_data['luis_treshold']
luis_scoring.to_csv(f'{output_folder}/{case}/luis_scoring.csv', sep = ',', encoding = 'utf-8', index=False)
# Finish run

Просмотреть файл

@ -26,7 +26,7 @@ def request_luis(text):
# Uncomment this if you are using the old url version having the region name as endpoint.
# endpoint_url = f'{endpoint}.api.cognitive.microsoft.com'.
# Below, you see the most current version of the api having the prediction resource name as endpoint.
endpoint_url = f'{pa.luis_endpoint}.cognitiveservices.azure.com'
endpoint_url = f'{pa.config_data["luis_endpoint"]}.cognitiveservices.azure.com'
headers = {}
params = {
'query': text,
@ -35,9 +35,9 @@ def request_luis(text):
'show-all-intents': 'true',
'spellCheck': 'false',
'staging': 'false',
'subscription-key': pa.luis_key
'subscription-key': pa.config_data['luis_key']
}
r = requests.get(f'https://{endpoint_url}/luis/prediction/v3.0/apps/{pa.luis_appid}/slots/{pa.luis_slot}/predict', headers=headers, params=params)
r = requests.get(f'https://{endpoint_url}/luis/prediction/v3.0/apps/{pa.config_data["luis_appid"]}/slots/{pa.config_data["luis_slot"]}/predict', headers=headers, params=params)
# Check
logging.debug(json.dumps(json.loads(r.text), indent=2))
return r.json()
@ -53,7 +53,7 @@ def luis_classification_report(df, col):
logging.info('[INFO] - Starting to create classification report')
logging.info('[OUTPUT] - CLASSIFICATION REPORT (without reset by treshold):')
logging.info(classification_report(df['intent'], df[f'prediction_{col}']))
logging.info(f'[OUTPUT] - AFTER RESET BY TRESHOLD ({pa.luis_treshold}):')
logging.info(f'[OUTPUT] - AFTER RESET BY TRESHOLD ({pa.config_data["luis_treshold"]}):')
logging.info(classification_report(df['intent'], df[f'prediction_drop_{col}']))
logging.info('[OUTPUT] - CONFUSION MATRIX:')
logging.info(f'\n{confusion_matrix(df["intent"], df[f"prediction_{col}"])}')
@ -79,7 +79,7 @@ def main(df, col):
top_intent = data['prediction']['topIntent']
top_score = data['prediction']['intents'][top_intent]['score']
# Evaluate scores based on treshold and set None-intent if confidence is too low
if top_score < pa.luis_treshold:
if top_score < pa.config_data['luis_treshold']:
drop = "None"
else:
drop = top_intent

Просмотреть файл

@ -50,28 +50,30 @@ def get_config(fname_config='config.ini'):
# Get config file
sys.path.append('./')
config = configparser.ConfigParser()
global output_folder, driver, luis_appid, luis_key, luis_region, luis_endpoint, luis_slot, luis_treshold, stt_key, stt_endpoint, stt_region, tts_key, tts_region, tts_resource_name, tts_language, tts_font
global config_data
#global output_folder, driver, luis_appid, luis_key, luis_region, luis_endpoint, luis_slot, luis_treshold, stt_key, stt_endpoint, stt_region, tts_key, tts_region, tts_resource_name, tts_language, tts_font
config.read(fname_config)
# Read keys/values and assign it to variables
try:
output_folder = config['dir']['output_folder']
stt_key = config['stt']['key']
stt_endpoint = config['stt']['endpoint']
stt_region = config['stt']['region']
tts_key = config['tts']['key']
tts_region = config['tts']['region']
tts_resource_name = config['tts']['resource_name']
tts_language = config['tts']['language']
tts_font = config['tts']['font']
luis_appid = config['luis']['app_id']
luis_key = config['luis']['key']
luis_region = config['luis']['region']
luis_endpoint = config['luis']['endpoint']
luis_slot = config['luis']['slot']
luis_treshold = float(config['luis']['treshold'])
luis_treshold = 0 if luis_treshold == '' else luis_treshold
driver = config['driver']['path']
config_data = dict()
config_data['output_folder'] = config['dir']['output_folder']
config_data['stt_key'] = config['stt']['key']
config_data['stt_endpoint'] = config['stt']['endpoint']
config_data['stt_region'] = config['stt']['region']
config_data['tts_key'] = config['tts']['key']
config_data['tts_region'] = config['tts']['region']
config_data['tts_resource_name'] = config['tts']['resource_name']
config_data['tts_language'] = config['tts']['language']
config_data['tts_font'] = config['tts']['font']
config_data['luis_appid'] = config['luis']['app_id']
config_data['luis_key'] = config['luis']['key']
config_data['luis_region'] = config['luis']['region']
config_data['luis_endpoint'] = config['luis']['endpoint']
config_data['luis_slot'] = config['luis']['slot']
config_data['luis_treshold'] = float(config['luis']['treshold'])
config_data['luis_treshold'] = 0 if config_data['luis_treshold'] == '' else config_data['luis_treshold']
config_data['driver'] = config['driver']['path']
except KeyError as e:
logging.error(f'[ERROR] - Exit with KeyError for {e}, please verify structure and existance of your config.ini file. You may use config.sample.ini as guidance.')
sys.exit()

Просмотреть файл

@ -90,7 +90,7 @@ def main(speech_files, output_directory, lexical = False, enable_proxy = False,
zip(filenames, results): Zipped lists of filenames and STT-results as string
"""
try:
speech_config = speechsdk.SpeechConfig(subscription = pa.stt_key, region = pa.stt_region)
speech_config = speechsdk.SpeechConfig(subscription = pa.config_data['stt_key'], region = pa.config_data['stt_region'])
except RuntimeError:
logging.error("[ERROR] - Could not retrieve speech config")
# If necessary, you can enable a proxy here:
@ -99,8 +99,8 @@ def main(speech_files, output_directory, lexical = False, enable_proxy = False,
speech_config.set_proxy(argv[0], argv[1], argv[2], argv[3])
# Set speech service properties, requesting the detailed response format to make it compatible with lexical format, if wanted
speech_config.set_service_property(name='format', value='detailed', channel=speechsdk.ServicePropertyChannel.UriQueryParameter)
if pa.stt_endpoint != "":
speech_config.endpoint_id = pa.stt_endpoint
if pa.config_data['stt_endpoint'] != "":
speech_config.endpoint_id = pa.config_data['stt_endpoint']
logging.info(f'[INFO] - Starting to transcribe {len(next(os.walk(speech_files))[2])} audio files')
results = []
filenames = []

Просмотреть файл

@ -128,7 +128,7 @@ def main(df, output_directory, custom=True, telephone=True):
"""
# Check if it's Windows for driver import - if not, setting of driver is not necessary
if os.name == "nt":
AudioSegment.ffmpeg = pa.driver
AudioSegment.ffmpeg = pa.config_data['driver']
logging.debug("Running on Windows")
else:
logging.debug("Running on Linux")
@ -136,19 +136,19 @@ def main(df, output_directory, custom=True, telephone=True):
os.makedirs(f'{output_directory}/tts_generated/', exist_ok=True)
audio_synth = []
# Instantiate SpeechConfig for the entire run, as well as voice name and audio format
speech_config = SpeechConfig(subscription=pa.tts_key, region=pa.tts_region)
speech_config.speech_synthesis_voice_name = f'{pa.tts_language}-{pa.tts_font}'
speech_config = SpeechConfig(subscription=pa.config_data['tts_key'], region=pa.config_data['tts_region'])
speech_config.speech_synthesis_voice_name = f'{pa.config_data["tts_language"]}-{pa.config_data["tts_font"]}'
speech_config.set_speech_synthesis_output_format(SpeechSynthesisOutputFormat['Riff24Khz16BitMonoPcm'])
# Loop through dataframe of utterances
for index, row in df.iterrows():
# Submit request to TTS
try:
fname = f"{datetime.today().strftime('%Y-%m-%d')}_{pa.tts_language}_{pa.tts_font}_{str(uuid.uuid4().hex)}.wav"
fname = f"{datetime.today().strftime('%Y-%m-%d')}_{pa.config_data['tts_language']}_{pa.config_data['tts_font']}_{str(uuid.uuid4().hex)}.wav"
# AudioOutputConfig has to be set separately due to the file names
audio_config = AudioOutputConfig(filename=f'{output_directory}/tts_generated/{fname}')
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
# Submit request and write outputs
synthesizer.speak_ssml_async(get_ssml_string(row['text'], pa.tts_language, pa.tts_font))
synthesizer.speak_ssml_async(get_ssml_string(row['text'], pa.config_data['tts_language'], pa.config_data['tts_font']))
except Exception as e:
logging.error(f'[ERROR] - Synthetization of "{row["text"]}" failed -> {e}')
audio_synth.append('nan')