#!/usr/bin/env python
# -*- coding: utf-8 -*-
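"""Tune the external scorer's lm_alpha and lm_beta hyperparameters with Optuna.

The script evaluates the acoustic model on the data sets given via --test_files
and minimizes the word error rate, or the character error rate when the scorer
is in UTF-8 (character-based) mode.

Example invocation (file name and flag values are illustrative only):

    python lm_optimizer.py --test_files test.csv --scorer_path kenlm.scorer \
        --n_trials 100 --lm_alpha_max 5 --lm_beta_max 5
"""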
from __future__ import absolute_import, print_function

import absl.app
import optuna
import sys
import tensorflow.compat.v1 as tfv1

from deepspeech_training.evaluate import evaluate
from deepspeech_training.train import create_model
from deepspeech_training.util.config import Config, initialize_globals
from deepspeech_training.util.flags import create_flags, FLAGS
from deepspeech_training.util.logging import log_error
from deepspeech_training.util.evaluate_tools import wer_cer_batch
from ds_ctcdecoder import Scorer


def character_based():
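    """Return True if the configured external scorer is in UTF-8 (character-based) mode."""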
    is_character_based = False
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
        is_character_based = scorer.is_utf8_mode()
    return is_character_based


def objective(trial):
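    """Evaluate one sampled (lm_alpha, lm_beta) pair and return its error rate.

    The sampled values are applied through FLAGS, each file in --test_files is
    evaluated in a fresh TensorFlow graph, and the intermediate error rate is
    reported to Optuna after every file so unpromising trials can be pruned.
    """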
    FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
    FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)

    is_character_based = trial.study.user_attrs['is_character_based']

    samples = []
    for step, test_file in enumerate(FLAGS.test_files.split(',')):
        tfv1.reset_default_graph()

        current_samples = evaluate([test_file], create_model)
        samples += current_samples

        # Report intermediate objective value.
        wer, cer = wer_cer_batch(current_samples)
        trial.report(cer if is_character_based else wer, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    wer, cer = wer_cer_batch(samples)
    return cer if is_character_based else wer


def main(_):
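    """Validate flags, create the Optuna study, and print the best parameters found."""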
    initialize_globals()

    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
        sys.exit(1)

    is_character_based = character_based()

    study = optuna.create_study()
    study.set_user_attr("is_character_based", is_character_based)
    study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
    print('Best params: lm_alpha={} and lm_beta={} with WER/CER={}'.format(
        study.best_params['lm_alpha'],
        study.best_params['lm_beta'],
        study.best_value))


if __name__ == '__main__':
    create_flags()
    absl.app.run(main)