This commit is contained in:
Alexandre Lissy 2019-10-14 14:25:02 +02:00
Родитель ef3bdb2540
Коммит 1939f74ec0
13 изменённых файлов: 125 добавлений и 46 удалений

Просмотреть файл

@ -771,6 +771,20 @@ def export():
from tensorflow.python.framework.ops import Tensor, Operation
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
assert graph_version > 0
# Reshape with dimension [1] required to avoid this error:
# ERROR: Input array not provided for operation 'reshape'.
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language')
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops)
@ -813,24 +827,12 @@ def export():
output_node_names=output_node_names.split(','),
placeholder_type_enum=tf.float32.as_datatype_enum)
frozen_graph = do_graph_freeze(output_node_names=output_names)
if not FLAGS.export_tflite:
frozen_graph = do_graph_freeze(output_node_names=output_names)
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
# Add a no-op node to the graph with metadata information to be loaded by the native client
metadata = frozen_graph.node.add()
metadata.name = 'model_metadata'
metadata.op = 'NoOp'
metadata.attr['sample_rate'].i = FLAGS.audio_sample_rate
metadata.attr['feature_win_len'].i = FLAGS.feature_win_len
metadata.attr['feature_win_step'].i = FLAGS.feature_win_step
if FLAGS.export_language:
metadata.attr['language'].s = FLAGS.export_language.encode('ascii')
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
frozen_graph = do_graph_freeze(output_node_names=output_names)
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())

Просмотреть файл

@ -1 +1 @@
3
4

Просмотреть файл

@ -12,9 +12,9 @@ ModelState::ModelState()
, n_context_(-1)
, n_features_(-1)
, mfcc_feats_per_timestep_(-1)
, sample_rate_(DEFAULT_SAMPLE_RATE)
, audio_win_len_(DEFAULT_WINDOW_LENGTH)
, audio_win_step_(DEFAULT_WINDOW_STEP)
, sample_rate_(-1)
, audio_win_len_(-1)
, audio_win_step_(-1)
, state_size_(-1)
{
}

Просмотреть файл

@ -15,10 +15,6 @@ struct ModelState {
//TODO: infer batch size from model/use dynamic batch size
static constexpr unsigned int BATCH_SIZE = 1;
static constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
Alphabet alphabet_;
std::unique_ptr<Scorer> scorer_;
unsigned int beam_width_;

Просмотреть файл

@ -1,5 +1,7 @@
#include "tflitemodelstate.h"
#include "workspace_status.h"
using namespace tflite;
using std::vector;
@ -123,6 +125,23 @@ TFLiteModelState::init(const char* model_path,
new_state_h_idx_ = get_output_tensor_by_name("new_state_h");
mfccs_idx_ = get_output_tensor_by_name("mfccs");
int metadata_version_idx = get_output_tensor_by_name("metadata_version");
// int metadata_language_idx = get_output_tensor_by_name("metadata_language");
int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
std::vector<int> metadata_exec_plan;
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
// metadata_exec_plan.push_back(find_parent_node_ids(metadata_language_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
for (int i = 0; i < metadata_exec_plan.size(); ++i) {
assert(metadata_exec_plan[i] > -1);
}
// When we call Interpreter::Invoke, the whole graph is executed by default,
// which means every time compute_mfcc is called the entire acoustic model is
// also executed. To workaround that problem, we walk up the dependency DAG
@ -131,15 +150,60 @@ TFLiteModelState::init(const char* model_path,
auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
auto orig_plan = interpreter_->execution_plan();
// Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
// Remove MFCC and Metatda nodes from original plan (all nodes) to create the acoustic model plan
auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan, &metadata_exec_plan](int elem) {
return (std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end()
|| std::find(metadata_exec_plan.begin(), metadata_exec_plan.end(), elem) != metadata_exec_plan.end());
});
orig_plan.erase(erase_begin, orig_plan.end());
acoustic_exec_plan_ = std::move(orig_plan);
mfcc_exec_plan_ = std::move(mfcc_plan);
interpreter_->SetExecutionPlan(metadata_exec_plan);
TfLiteStatus status = interpreter_->Invoke();
if (status != kTfLiteOk) {
std::cerr << "Error running session: " << status << "\n";
return DS_ERR_FAIL_INTERPRETER;
}
int* const graph_version = interpreter_->typed_tensor<int>(metadata_version_idx);
if (graph_version == nullptr) {
std::cerr << "Unable to read model file version." << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
if (*graph_version < ds_graph_version()) {
std::cerr << "Specified model file version (" << *graph_version << ") is "
<< "incompatible with minimum version supported by this client ("
<< ds_graph_version() << "). See "
<< "https://github.com/mozilla/DeepSpeech/blob/master/USING.rst#model-compatibility "
<< "for more information" << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
int* const model_sample_rate = interpreter_->typed_tensor<int>(metadata_sample_rate_idx);
if (model_sample_rate == nullptr) {
std::cerr << "Unable to read model sample rate." << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
sample_rate_ = *model_sample_rate;
int* const win_len_ms = interpreter_->typed_tensor<int>(metadata_feature_win_len_idx);
int* const win_step_ms = interpreter_->typed_tensor<int>(metadata_feature_win_step_idx);
if (win_len_ms == nullptr || win_step_ms == nullptr) {
std::cerr << "Unable to read model feature window informations." << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);
assert(sample_rate_ > 0);
assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0);
TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;
n_steps_ = dims_input_node->data[1];

Просмотреть файл

@ -78,7 +78,20 @@ TFModelState::init(const char* model_path,
return DS_ERR_FAIL_CREATE_SESS;
}
int graph_version = graph_def_.version();
std::vector<tensorflow::Tensor> metadata_outputs;
status = session_->Run({}, {
"metadata_version",
// "metadata_language",
"metadata_sample_rate",
"metadata_feature_win_len",
"metadata_feature_win_step"
}, {}, &metadata_outputs);
if (!status.ok()) {
std::cout << "Unable to fetch metadata: " << status << std::endl;
return DS_ERR_MODEL_INCOMPATIBLE;
}
int graph_version = metadata_outputs[0].scalar<int>()();
if (graph_version < ds_graph_version()) {
std::cerr << "Specified model file version (" << graph_version << ") is "
<< "incompatible with minimum version supported by this client ("
@ -88,6 +101,16 @@ TFModelState::init(const char* model_path,
return DS_ERR_MODEL_INCOMPATIBLE;
}
sample_rate_ = metadata_outputs[1].scalar<int>()();
int win_len_ms = metadata_outputs[2].scalar<int>()();
int win_step_ms = metadata_outputs[3].scalar<int>()();
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
assert(sample_rate_ > 0);
assert(audio_win_len_ > 0);
assert(audio_win_step_ > 0);
for (int i = 0; i < graph_def_.node_size(); ++i) {
NodeDef node = graph_def_.node(i);
if (node.name() == "input_node") {
@ -115,12 +138,6 @@ TFModelState::init(const char* model_path,
<< std::endl;
return DS_ERR_INVALID_ALPHABET;
}
} else if (node.name() == "model_metadata") {
sample_rate_ = node.attr().at("sample_rate").i();
int win_len_ms = node.attr().at("feature_win_len").i();
int win_step_ms = node.attr().at("feature_win_step").i();
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
}
}

Просмотреть файл

@ -30,7 +30,7 @@ then:
image: ${build.docker_image}
env:
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.8/models.tar.gz"
DEEPSPEECH_MODEL: "https://github.com/lissyx/DeepSpeech/releases/download/test-model-0.6.0a10/models.tar.gz"
DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz"
PIP_DEFAULT_TIMEOUT: "60"

Просмотреть файл

@ -252,12 +252,12 @@ assert_correct_multi_ldc93s1()
assert_correct_ldc93s1_prodmodel()
{
assert_correct_inference "$1" "she had reduce suit in greasy water all year" "$2"
assert_correct_inference "$1" "she had i do so in greasy wash for a year" "$2"
}
assert_correct_ldc93s1_prodmodel_stereo_44k()
{
assert_correct_inference "$1" "she had reduce suit in greasy water all year" "$2"
assert_correct_inference "$1" "she had the doctor in greasy wash for a year" "$2"
}
assert_correct_warning_upsampling()
@ -436,7 +436,7 @@ run_prod_concurrent_stream_tests()
output2=$(echo "${output}" | tail -n 1)
assert_correct_ldc93s1_prodmodel "${output1}" "${status}"
assert_correct_inference "${output2}" "i must find a new home in the stars" "${status}"
assert_correct_inference "${output2}" "we must find a new home in the stars" "${status}"
}
run_prod_inference_tests()

Просмотреть файл

@ -38,8 +38,8 @@ then:
DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_arm64_build}/artifacts/public
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
PIP_DEFAULT_TIMEOUT: "60"
PIP_EXTRA_INDEX_URL: "https://lissyx.github.io/deepspeech-python-wheels/"
EXTRA_PYTHON_CONFIGURE_OPTS: "" # Required by Debian Buster

Просмотреть файл

@ -43,8 +43,8 @@ then:
DEEPSPEECH_ARTIFACTS_TFLITE_ROOT: https://queue.taskcluster.net/v1/task/${darwin_amd64_tflite}/artifacts/public
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
EXPECTED_TENSORFLOW_VERSION: "${build.tensorflow_git_desc}"
command:

Просмотреть файл

@ -43,8 +43,8 @@ then:
DEEPSPEECH_ARTIFACTS_TFLITE_ROOT: https://queue.taskcluster.net/v1/task/${linux_amd64_tflite}/artifacts/public
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
DECODER_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_amd64_ctc}/artifacts/public
PIP_DEFAULT_TIMEOUT: "60"
EXPECTED_TENSORFLOW_VERSION: "${build.tensorflow_git_desc}"

Просмотреть файл

@ -38,8 +38,8 @@ then:
DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_rpi3_build}/artifacts/public
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
PIP_DEFAULT_TIMEOUT: "60"
PIP_EXTRA_INDEX_URL: "https://www.piwheels.org/simple"
EXTRA_PYTHON_CONFIGURE_OPTS: "" # Required by Raspbian Buster / PiWheels

Просмотреть файл

@ -45,8 +45,8 @@ then:
DEEPSPEECH_ARTIFACTS_TFLITE_ROOT: https://queue.taskcluster.net/v1/task/${win_amd64_tflite}/artifacts/public
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
EXPECTED_TENSORFLOW_VERSION: "${build.tensorflow_git_desc}"
TC_MSYS_VERSION: 'MSYS_NT-6.3'
MSYS: 'winsymlinks:nativestrict'