Remove scorer after it's been used, but only if it was generated

This commit is contained in:
Bias 2020-06-26 12:59:34 +02:00 коммит произвёл Tilman Kamp
Родитель 3dc0fbb44c
Коммит 6b7270c96a
1 изменённых файлов: 10 добавлений и 3 удалений

Просмотреть файл

@ -420,6 +420,7 @@ def main():
output_graph_path = None
for audio_path, tlog_path, script_path, aligned_path in to_prepare:
if not exists(tlog_path):
generated_scorer = False
if output_graph_path is None:
logging.debug('Looking for model files in "{}"...'.format(model_dir))
output_graph_path = glob(model_dir + "/*.pbmm")[0]
@ -453,6 +454,8 @@ def main():
create_bundle(alphabet_path, scorer_path + '.' + 'lm.binary', scorer_path + '.' + 'vocab-500000.txt', scorer_path, False, 0.931289039105002, 1.1834137581510284)
os.remove(scorer_path + '.' + 'lm.binary')
os.remove(scorer_path + '.' + 'vocab-500000.txt')
generated_scorer = True
else:
scorer_path = lang_scorer_path
@ -506,6 +509,10 @@ def main():
logging.debug('Writing transcription log to file "{}"...'.format(tlog_path))
with open(tlog_path, 'w', encoding='utf-8') as tlog_file:
tlog_file.write(json.dumps(fragments, indent=4 if args.output_pretty else None, ensure_ascii=False))
# Remove the scorer only if it was generated here; a pre-existing lang_scorer_path scorer is kept
if generated_scorer:
os.remove(scorer_path)
if not path.isfile(tlog_path):
fail('Problem loading transcript from "{}"'.format(tlog_path))
to_align.append((tlog_path, script_path, aligned_path))
@ -584,13 +591,13 @@ def parse_args():
stt_group.add_argument('--stt-model-rate', type=int, default=DEFAULT_RATE,
help='Supported sample rate of the acoustic model')
stt_group.add_argument('--stt-model-dir', required=False,
help='Path to a directory with output_graph, lm, trie and (optional) alphabet file ' +
'(default: "data/en"')
help='Path to a directory with output_graph, scorer and (optional) alphabet file ' +
'(default: "models/en"')
stt_group.add_argument('--stt-no-own-lm', action="store_true",
help='Deactivates creation of individual language models per document.' +
'Uses the one from model dir instead.')
stt_group.add_argument('--stt-workers', type=int, required=False, default=1,
help='Number of parallel STT workers - should 1 for GPU based DeepSpeech')
help='Number of parallel STT workers - should be 1 for GPU based DeepSpeech')
stt_group.add_argument('--stt-min-duration', type=int, required=False, default=100,
help='Minimum speech fragment duration in milliseconds to translate (default: 100)')
stt_group.add_argument('--stt-max-duration', type=int, required=False,