Add version info to exported graphs

This commit is contained in:
Reuben Morais 2019-04-02 10:29:57 -03:00
Parent 4e9e78fefe
Commit a7cda8e761
9 changed files: 232 additions and 110 deletions

View file: DeepSpeech.py

@@ -665,85 +665,88 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
    )

def file_relative_read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()

def export():
    r'''
    Restores the trained variables into a simpler graph that will be exported for serving.
    '''
    log_info('Exporting the model...')
    with tf.device('/cpu:0'):
        from tensorflow.python.framework.ops import Tensor, Operation

        tf.reset_default_graph()
        session = tf.Session(config=Config.session_config)

        inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
        input_names = ",".join(tensor.op.name for tensor in inputs.values())
        output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
        output_names_ops = [tensor.name for tensor in outputs.values() if isinstance(tensor, Operation)]
        output_names = ",".join(output_names_tensors + output_names_ops)
        input_shapes = ":".join(",".join(map(str, tensor.shape)) for tensor in inputs.values())

        if not FLAGS.export_tflite:
            mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
        else:
            # Create a saver using variables from the above newly created graph
            def fixup(name):
                if name.startswith('rnn/lstm_cell/'):
                    return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
                return name

            mapping = {fixup(v.op.name): v for v in tf.global_variables()}

        saver = tf.train.Saver(mapping)

        # Restore variables from training checkpoint
        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        checkpoint_path = checkpoint.model_checkpoint_path

        output_filename = 'output_graph.pb'
        if FLAGS.remove_export:
            if os.path.isdir(FLAGS.export_dir):
                log_info('Removing old export')
                shutil.rmtree(FLAGS.export_dir)
        try:
            output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

            if not os.path.isdir(FLAGS.export_dir):
                os.makedirs(FLAGS.export_dir)

            def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
                return freeze_graph.freeze_graph_with_def_protos(
                    input_graph_def=tf.get_default_graph().as_graph_def(),
                    input_saver_def=saver.as_saver_def(),
                    input_checkpoint=checkpoint_path,
                    output_node_names=output_node_names,
                    restore_op_name=None,
                    filename_tensor_name=None,
                    output_graph=output_file,
                    clear_devices=False,
                    variable_names_blacklist=variables_blacklist,
                    initializer_nodes='')

            if not FLAGS.export_tflite:
                frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
                frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())

                with open(output_graph_path, 'wb') as fout:
                    fout.write(frozen_graph.SerializeToString())
            else:
                frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
                output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

                converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
                converter.post_training_quantize = True
                # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
                converter.allow_custom_ops = True
                tflite_model = converter.convert()

                with open(output_tflite_path, 'wb') as fout:
                    fout.write(tflite_model)

                log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))

            log_info('Models exported at %s' % (FLAGS.export_dir))
        except RuntimeError as e:
            log_error(str(e))
def do_single_file_inference(input_file_path):
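With the `frozen_graph.version` stamp in place, the exported protobuf carries its own graph version. A minimal readback sketch (added here for illustration, not part of the commit), assuming TensorFlow 1.x and an already-exported `output_graph.pb` in the working directory:

# Hedged sketch: confirm the version stamp written by export() above.
# Assumes TF 1.x and an exported output_graph.pb in the current directory.
import tensorflow as tf

graph_def = tf.GraphDef()
with open('output_graph.pb', 'rb') as fin:
    graph_def.ParseFromString(fin.read())

# export() overwrites GraphDef.version with the GRAPH_VERSION file's value,
# so a model exported by this commit should print 1 here.
print('graph version:', graph_def.version)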
@@ -795,18 +798,20 @@ def main(_):
    initialize_globals()

    if FLAGS.train:
        tf.reset_default_graph()
        tf.set_random_seed(FLAGS.random_seed)
        train()

    if FLAGS.test:
        tf.reset_default_graph()
        test()

    if FLAGS.export_dir:
        tf.reset_default_graph()
        export()

    if len(FLAGS.one_shot_infer):
        tf.reset_default_graph()
        do_single_file_inference(FLAGS.one_shot_infer)

if __name__ == '__main__':

GRAPH_VERSION Normal file
View file

@@ -0,0 +1 @@
1

View file: README.md

@@ -33,10 +33,11 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech`.
- [Prerequisites](#prerequisites)
- [Getting the code](#getting-the-code)
- [Getting the pre-trained model](#getting-the-pre-trained-model)
- [Using the model](#using-the-model)
- [CUDA dependency](#cuda-dependency)
- [Model compatibility](#model-compatibility)
- [Using the Python package](#using-the-python-package)
- [Using the command-line client](#using-the-command-line-client)
- [Using the Node.JS package](#using-the-nodejs-package)
- [Installing bindings from source](#installing-bindings-from-source)
- [Third party bindings](#third-party-bindings)
@@ -48,6 +49,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech`.
- [Checkpointing](#checkpointing)
- [Exporting a model for inference](#exporting-a-model-for-inference)
- [Exporting a model for TFLite](#exporting-a-model-for-tflite)
- [Making a mmap-able model for inference](#making-a-mmap-able-model-for-inference)
- [Continuing training from a release model](#continuing-training-from-a-release-model)
- [Contact/Getting Help](#contactgetting-help)
@@ -88,6 +90,10 @@ There are three ways to use DeepSpeech inference:

The GPU-capable builds (Python, NodeJS, C++, etc.) depend on the same CUDA runtime as upstream TensorFlow. Currently, with TensorFlow r1.12, that means CUDA 9.0 and CuDNN v7.2.

### Model compatibility

DeepSpeech models are versioned to keep you from trying to use an incompatible graph with a newer client after a breaking change has been made to the code. If you get an error saying your model file version is too old for the client, you should either upgrade to a newer model release, re-export your model from the checkpoint using a newer version of the code, or downgrade your client if you need to use the old model and can't re-export it.
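The check the native client performs in `DS_CreateModel` (shown later in this commit) can also be mirrored offline before shipping a model. A hedged Python sketch, assuming TensorFlow 1.x; `MINIMUM_GRAPH_VERSION` is a stand-in for the `DS_GRAPH_VERSION` value baked into your client build:

# Sketch of the client-side compatibility check, done ahead of time in
# Python. MINIMUM_GRAPH_VERSION stands in for the client's DS_GRAPH_VERSION.
import tensorflow as tf

MINIMUM_GRAPH_VERSION = 1  # hypothetical; match your client build

def model_is_compatible(model_path):
    graph_def = tf.GraphDef()
    with open(model_path, 'rb') as fin:
        graph_def.ParseFromString(fin.read())
    return graph_def.version >= MINIMUM_GRAPH_VERSION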
### Using the Python package
Pre-built binaries for performing inference with a trained model can be installed with `pip3`. You can then use the `deepspeech` binary to do speech-to-text on an audio file.
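Alongside the command-line binary, the package exposes a Python API. A hedged usage sketch; the constructor arguments (26 MFCC features, 9 context frames, beam width 500) follow the 0.4-era client defaults and may differ in your release:

# Illustrative 0.4/0.5-era deepspeech Python API usage; signatures changed
# across releases, so verify against your installed version.
import scipy.io.wavfile as wav
from deepspeech import Model

ds = Model('output_graph.pb', 26, 9, 'alphabet.txt', 500)
fs, audio = wav.read('audio.wav')  # 16-bit PCM mono expected
print(ds.stt(audio, fs))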

View file: native_client/BUILD

@@ -21,6 +21,14 @@ genrule(
    local = 1,
)

genrule(
    name = "ds_graph_version",
    outs = ["ds_graph_version.h"],
    cmd = "$(location :ds_graph_version.sh) >$@",
    tools = [":ds_graph_version.sh"],
    local = 1,
)

KENLM_SOURCES = glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
                      "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
                     exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"])

@@ -62,7 +70,8 @@ tf_cc_shared_object(
    srcs = ["deepspeech.cc",
            "deepspeech.h",
            "alphabet.h",
            "ds_version.h",
            "ds_graph_version.h"] +
           DECODER_SOURCES,
    copts = select({
        # -fvisibility=hidden is not required on Windows, MSVC hides all declarations by default

View file

@@ -16,7 +16,7 @@ const float NUM_FLT_LOGE = 0.4342944819;
inline void check(
    bool x, const char *expr, const char *file, int line, const char *err) {
  if (!x) {
    std::cerr << "[" << file << ":" << line << "] ";
    LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
  }
}

View file: native_client/deepspeech.cc

@@ -13,6 +13,7 @@
#include "alphabet.h"
#include "native_client/ds_version.h"
#include "native_client/ds_graph_version.h"

#ifndef USE_TFLITE
#include "tensorflow/core/public/session.h"

@@ -654,6 +655,16 @@ DS_CreateModel(const char* aModelPath,
    return DS_ERR_FAIL_CREATE_SESS;
  }

  int graph_version = model->graph_def.version();
  if (graph_version < DS_GRAPH_VERSION) {
    std::cerr << "Specified model file version (" << graph_version << ") is "
              << "incompatible with minimum version supported by this client ("
              << DS_GRAPH_VERSION << "). See "
              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
              << "for more information" << std::endl;
    return DS_ERR_MODEL_INCOMPATIBLE;
  }

  for (int i = 0; i < model->graph_def.node_size(); ++i) {
    NodeDef node = model->graph_def.node(i);
    if (node.name() == "input_node") {

View file: native_client/deepspeech.h

@@ -40,6 +40,7 @@ enum DeepSpeech_Error_Codes
  DS_ERR_INVALID_ALPHABET = 0x2000,
  DS_ERR_INVALID_SHAPE = 0x2001,
  DS_ERR_INVALID_LM = 0x2002,
  DS_ERR_MODEL_INCOMPATIBLE = 0x2003,

  // Runtime failures
  DS_ERR_FAIL_INIT_MMAP = 0x3000,

View file: native_client/ds_graph_version.sh Normal file

@@ -0,0 +1,19 @@
#!/bin/bash

if [ `uname` = "Darwin" ]; then
  export PATH="/Users/build-user/TaskCluster/Workdir/tasks/tc-workdir/homebrew/opt/coreutils/libexec/gnubin:${PATH}"
fi

DS_DIR="$(realpath "$(dirname "$(realpath "$0")")/../")"

if [ ! -d "${DS_DIR}" ]; then
  exit 1
fi

DS_GRAPH_VERSION=$(cat "${DS_DIR}/GRAPH_VERSION")
if [ $? -ne 0 ]; then
  exit 1
fi

cat <<EOF
#define DS_GRAPH_VERSION ${DS_GRAPH_VERSION}
EOF
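The `ds_graph_version` genrule in native_client/BUILD runs this script at build time to produce `ds_graph_version.h`. For clarity, the same generation logic as a short Python sketch (a hypothetical equivalent, not part of the commit); with the `GRAPH_VERSION` file added above containing `1`, it emits `#define DS_GRAPH_VERSION 1`:

# Hypothetical Python equivalent of ds_graph_version.sh, for illustration:
# read GRAPH_VERSION from the repo root and emit the header contents.
import os
import sys

def emit_graph_version_header(ds_dir):
    with open(os.path.join(ds_dir, 'GRAPH_VERSION')) as fin:
        version = fin.read().strip()
    sys.stdout.write('#define DS_GRAPH_VERSION {}\n'.format(version))

emit_graph_version_header('.')  # run from the repository root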

View file: taskcluster/tc-tests-utils.sh

@@ -67,6 +67,22 @@ assert_correct_inference()
{
  phrase=$(strip "$1")
  expected=$(strip "$2")
  status=$3

  if [ "$status" -ne "0" ]; then
    case "$(cat /tmp/stderr)" in
      *"incompatible with minimum version"*)
        echo "Prod model too old for client, skipping test."
        return 0
      ;;
      *)
        echo "Client failed to run:"
        cat /tmp/stderr
        return 1
      ;;
    esac
  fi

  if [ -z "${phrase}" -o -z "${expected}" ]; then
    echo "One or more empty strings:"

@@ -95,6 +111,7 @@ assert_working_inference()
{
  phrase=$1
  expected=$2
  status=$3

  if [ -z "${phrase}" -o -z "${expected}" ]; then
    echo "One or more empty strings:"

@@ -103,6 +120,21 @@ assert_working_inference()
    return 1
  fi;

  if [ "$status" -ne "0" ]; then
    case "$(cat /tmp/stderr)" in
      *"incompatible with minimum version"*)
        echo "Prod model too old for client, skipping test."
        return 0
      ;;
      *)
        echo "Client failed to run:"
        cat /tmp/stderr
        return 1
      ;;
    esac
  fi

  case "${phrase}" in
    *${expected}*)
      echo "Proper output has been produced:"
@@ -186,40 +218,40 @@ assert_not_present()

assert_correct_ldc93s1()
{
  assert_correct_inference "$1" "she had your dark suit in greasy wash water all year" "$2"
}

assert_working_ldc93s1()
{
  assert_working_inference "$1" "she had your dark suit in greasy wash water all year" "$2"
}

assert_correct_ldc93s1_lm()
{
  assert_correct_inference "$1" "she had your dark suit in greasy wash water all year" "$2"
}

assert_working_ldc93s1_lm()
{
  assert_working_inference "$1" "she had your dark suit in greasy wash water all year" "$2"
}

assert_correct_multi_ldc93s1()
{
  assert_shows_something "$1" "/LDC93S1.wav%she had your dark suit in greasy wash water all year%" "$?"
  assert_shows_something "$1" "/LDC93S1_pcms16le_2_44100.wav%she had your dark suit in greasy wash water all year%" "$?"

  ## 8k will output garbage anyway ...
  # assert_shows_something "$1" "/LDC93S1_pcms16le_1_8000.wav%she hayorasryrtl lyreasy asr watal w water all year%"
}

assert_correct_ldc93s1_prodmodel()
{
  assert_correct_inference "$1" "she had a due and greasy wash water year" "$2"
}

assert_correct_ldc93s1_prodmodel_stereo_44k()
{
  assert_correct_inference "$1" "she had a due and greasy wash water year" "$2"
}
assert_correct_warning_upsampling()
@@ -249,43 +281,66 @@ check_tensorflow_version()

run_tflite_basic_inference_tests()
{
  set +e
  phrase_pbmodel_nolm=$(${DS_BINARY_PREFIX}deepspeech --model ${ANDROID_TMP_DIR}/ds/${model_name} --alphabet ${ANDROID_TMP_DIR}/ds/alphabet.txt --audio ${ANDROID_TMP_DIR}/ds/LDC93S1.wav 2>/tmp/stderr)
  set -e

  assert_correct_ldc93s1 "${phrase_pbmodel_nolm}" "$?"
}

run_netframework_inference_tests()
{
  set +e
  phrase_pbmodel_nolm=$(DeepSpeechConsole.exe --model ${TASKCLUSTER_TMP_DIR}/${model_name} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  set -e

  assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?"

  set +e
  phrase_pbmodel_nolm=$(DeepSpeechConsole.exe --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  set -e

  assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?"

  set +e
  phrase_pbmodel_withlm=$(DeepSpeechConsole.exe --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  set -e

  assert_working_ldc93s1_lm "${phrase_pbmodel_withlm}" "$?"
}

run_basic_inference_tests()
{
  set +e
  phrase_pbmodel_nolm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1 "${phrase_pbmodel_nolm}" "$status"

  set +e
  phrase_pbmodel_nolm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1 "${phrase_pbmodel_nolm}" "$status"

  set +e
  phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm}" "$status"
}

run_all_inference_tests()
{
  run_basic_inference_tests

  set +e
  phrase_pbmodel_nolm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1 "${phrase_pbmodel_nolm_stereo_44k}" "$status"

  set +e
  phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_stereo_44k}" "$status"

  phrase_pbmodel_nolm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
  assert_correct_warning_upsampling "${phrase_pbmodel_nolm_mono_8k}"
@@ -296,14 +351,23 @@ run_all_inference_tests()

run_prod_inference_tests()
{
  set +e
  phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status"

  set +e
  phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status"

  set +e
  phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>/tmp/stderr)
  status=$?
  set -e

  assert_correct_ldc93s1_prodmodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status"

  phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
  assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
@@ -311,11 +375,17 @@ run_prod_inference_tests()

run_multi_inference_tests()
{
  set +e -o pipefail
  multi_phrase_pbmodel_nolm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/ 2>/tmp/stderr | tr '\n' '%')
  status=$?
  set -e +o pipefail

  assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_nolm}" "$status"

  set +e -o pipefail
  multi_phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/ 2>/tmp/stderr | tr '\n' '%')
  status=$?
  set -e +o pipefail

  assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status"
}
android_run_tests()