зеркало из https://github.com/mozilla/DeepSpeech.git
Store graph version on TFLite
This commit is contained in:
Родитель
ef3bdb2540
Коммит
1939f74ec0
|
@ -771,6 +771,20 @@ def export():
|
|||
from tensorflow.python.framework.ops import Tensor, Operation
|
||||
|
||||
inputs, outputs, _ = create_inference_graph(batch_size=FLAGS.export_batch_size, n_steps=FLAGS.n_steps, tflite=FLAGS.export_tflite)
|
||||
|
||||
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
assert graph_version > 0
|
||||
|
||||
# Reshape with dimension [1] required to avoid this error:
|
||||
# ERROR: Input array not provided for operation 'reshape'.
|
||||
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||
|
||||
if FLAGS.export_language:
|
||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language')
|
||||
|
||||
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
|
||||
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
|
||||
output_names = ",".join(output_names_tensors + output_names_ops)
|
||||
|
@ -813,24 +827,12 @@ def export():
|
|||
output_node_names=output_node_names.split(','),
|
||||
placeholder_type_enum=tf.float32.as_datatype_enum)
|
||||
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names)
|
||||
|
||||
if not FLAGS.export_tflite:
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names)
|
||||
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||
|
||||
# Add a no-op node to the graph with metadata information to be loaded by the native client
|
||||
metadata = frozen_graph.node.add()
|
||||
metadata.name = 'model_metadata'
|
||||
metadata.op = 'NoOp'
|
||||
metadata.attr['sample_rate'].i = FLAGS.audio_sample_rate
|
||||
metadata.attr['feature_win_len'].i = FLAGS.feature_win_len
|
||||
metadata.attr['feature_win_step'].i = FLAGS.feature_win_step
|
||||
if FLAGS.export_language:
|
||||
metadata.attr['language'].s = FLAGS.export_language.encode('ascii')
|
||||
|
||||
with open(output_graph_path, 'wb') as fout:
|
||||
fout.write(frozen_graph.SerializeToString())
|
||||
else:
|
||||
frozen_graph = do_graph_freeze(output_node_names=output_names)
|
||||
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
|
||||
|
||||
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
|
||||
|
|
|
@ -1 +1 @@
|
|||
3
|
||||
4
|
||||
|
|
|
@ -12,9 +12,9 @@ ModelState::ModelState()
|
|||
, n_context_(-1)
|
||||
, n_features_(-1)
|
||||
, mfcc_feats_per_timestep_(-1)
|
||||
, sample_rate_(DEFAULT_SAMPLE_RATE)
|
||||
, audio_win_len_(DEFAULT_WINDOW_LENGTH)
|
||||
, audio_win_step_(DEFAULT_WINDOW_STEP)
|
||||
, sample_rate_(-1)
|
||||
, audio_win_len_(-1)
|
||||
, audio_win_step_(-1)
|
||||
, state_size_(-1)
|
||||
{
|
||||
}
|
||||
|
|
|
@ -15,10 +15,6 @@ struct ModelState {
|
|||
//TODO: infer batch size from model/use dynamic batch size
|
||||
static constexpr unsigned int BATCH_SIZE = 1;
|
||||
|
||||
static constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
|
||||
static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
|
||||
static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
|
||||
|
||||
Alphabet alphabet_;
|
||||
std::unique_ptr<Scorer> scorer_;
|
||||
unsigned int beam_width_;
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#include "tflitemodelstate.h"
|
||||
|
||||
#include "workspace_status.h"
|
||||
|
||||
using namespace tflite;
|
||||
using std::vector;
|
||||
|
||||
|
@ -123,6 +125,23 @@ TFLiteModelState::init(const char* model_path,
|
|||
new_state_h_idx_ = get_output_tensor_by_name("new_state_h");
|
||||
mfccs_idx_ = get_output_tensor_by_name("mfccs");
|
||||
|
||||
int metadata_version_idx = get_output_tensor_by_name("metadata_version");
|
||||
// int metadata_language_idx = get_output_tensor_by_name("metadata_language");
|
||||
int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
|
||||
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
|
||||
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
|
||||
|
||||
std::vector<int> metadata_exec_plan;
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
|
||||
// metadata_exec_plan.push_back(find_parent_node_ids(metadata_language_idx)[0]);
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
|
||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
|
||||
|
||||
for (int i = 0; i < metadata_exec_plan.size(); ++i) {
|
||||
assert(metadata_exec_plan[i] > -1);
|
||||
}
|
||||
|
||||
// When we call Interpreter::Invoke, the whole graph is executed by default,
|
||||
// which means every time compute_mfcc is called the entire acoustic model is
|
||||
// also executed. To work around that problem, we walk up the dependency DAG
|
||||
|
@ -131,15 +150,60 @@ TFLiteModelState::init(const char* model_path,
|
|||
auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
|
||||
auto orig_plan = interpreter_->execution_plan();
|
||||
|
||||
// Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
|
||||
auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
|
||||
return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
|
||||
// Remove MFCC and metadata nodes from original plan (all nodes) to create the acoustic model plan
|
||||
auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan, &metadata_exec_plan](int elem) {
|
||||
return (std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end()
|
||||
|| std::find(metadata_exec_plan.begin(), metadata_exec_plan.end(), elem) != metadata_exec_plan.end());
|
||||
});
|
||||
orig_plan.erase(erase_begin, orig_plan.end());
|
||||
|
||||
acoustic_exec_plan_ = std::move(orig_plan);
|
||||
mfcc_exec_plan_ = std::move(mfcc_plan);
|
||||
|
||||
interpreter_->SetExecutionPlan(metadata_exec_plan);
|
||||
TfLiteStatus status = interpreter_->Invoke();
|
||||
if (status != kTfLiteOk) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return DS_ERR_FAIL_INTERPRETER;
|
||||
}
|
||||
|
||||
int* const graph_version = interpreter_->typed_tensor<int>(metadata_version_idx);
|
||||
if (graph_version == nullptr) {
|
||||
std::cerr << "Unable to read model file version." << std::endl;
|
||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
if (*graph_version < ds_graph_version()) {
|
||||
std::cerr << "Specified model file version (" << *graph_version << ") is "
|
||||
<< "incompatible with minimum version supported by this client ("
|
||||
<< ds_graph_version() << "). See "
|
||||
<< "https://github.com/mozilla/DeepSpeech/blob/master/USING.rst#model-compatibility "
|
||||
<< "for more information" << std::endl;
|
||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
int* const model_sample_rate = interpreter_->typed_tensor<int>(metadata_sample_rate_idx);
|
||||
if (model_sample_rate == nullptr) {
|
||||
std::cerr << "Unable to read model sample rate." << std::endl;
|
||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
sample_rate_ = *model_sample_rate;
|
||||
|
||||
int* const win_len_ms = interpreter_->typed_tensor<int>(metadata_feature_win_len_idx);
|
||||
int* const win_step_ms = interpreter_->typed_tensor<int>(metadata_feature_win_step_idx);
|
||||
if (win_len_ms == nullptr || win_step_ms == nullptr) {
|
||||
std::cerr << "Unable to read model feature window informations." << std::endl;
|
||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
|
||||
audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);
|
||||
|
||||
assert(sample_rate_ > 0);
|
||||
assert(audio_win_len_ > 0);
|
||||
assert(audio_win_step_ > 0);
|
||||
|
||||
TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;
|
||||
|
||||
n_steps_ = dims_input_node->data[1];
|
||||
|
|
|
@ -78,7 +78,20 @@ TFModelState::init(const char* model_path,
|
|||
return DS_ERR_FAIL_CREATE_SESS;
|
||||
}
|
||||
|
||||
int graph_version = graph_def_.version();
|
||||
std::vector<tensorflow::Tensor> metadata_outputs;
|
||||
status = session_->Run({}, {
|
||||
"metadata_version",
|
||||
// "metadata_language",
|
||||
"metadata_sample_rate",
|
||||
"metadata_feature_win_len",
|
||||
"metadata_feature_win_step"
|
||||
}, {}, &metadata_outputs);
|
||||
if (!status.ok()) {
|
||||
std::cout << "Unable to fetch metadata: " << status << std::endl;
|
||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
int graph_version = metadata_outputs[0].scalar<int>()();
|
||||
if (graph_version < ds_graph_version()) {
|
||||
std::cerr << "Specified model file version (" << graph_version << ") is "
|
||||
<< "incompatible with minimum version supported by this client ("
|
||||
|
@ -88,6 +101,16 @@ TFModelState::init(const char* model_path,
|
|||
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||
}
|
||||
|
||||
sample_rate_ = metadata_outputs[1].scalar<int>()();
|
||||
int win_len_ms = metadata_outputs[2].scalar<int>()();
|
||||
int win_step_ms = metadata_outputs[3].scalar<int>()();
|
||||
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
|
||||
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
|
||||
|
||||
assert(sample_rate_ > 0);
|
||||
assert(audio_win_len_ > 0);
|
||||
assert(audio_win_step_ > 0);
|
||||
|
||||
for (int i = 0; i < graph_def_.node_size(); ++i) {
|
||||
NodeDef node = graph_def_.node(i);
|
||||
if (node.name() == "input_node") {
|
||||
|
@ -115,12 +138,6 @@ TFModelState::init(const char* model_path,
|
|||
<< std::endl;
|
||||
return DS_ERR_INVALID_ALPHABET;
|
||||
}
|
||||
} else if (node.name() == "model_metadata") {
|
||||
sample_rate_ = node.attr().at("sample_rate").i();
|
||||
int win_len_ms = node.attr().at("feature_win_len").i();
|
||||
int win_step_ms = node.attr().at("feature_win_step").i();
|
||||
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
|
||||
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ then:
|
|||
image: ${build.docker_image}
|
||||
|
||||
env:
|
||||
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.8/models.tar.gz"
|
||||
DEEPSPEECH_MODEL: "https://github.com/lissyx/DeepSpeech/releases/download/test-model-0.6.0a10/models.tar.gz"
|
||||
DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz"
|
||||
PIP_DEFAULT_TIMEOUT: "60"
|
||||
|
||||
|
|
|
@ -252,12 +252,12 @@ assert_correct_multi_ldc93s1()
|
|||
|
||||
assert_correct_ldc93s1_prodmodel()
|
||||
{
|
||||
assert_correct_inference "$1" "she had reduce suit in greasy water all year" "$2"
|
||||
assert_correct_inference "$1" "she had i do so in greasy wash for a year" "$2"
|
||||
}
|
||||
|
||||
assert_correct_ldc93s1_prodmodel_stereo_44k()
|
||||
{
|
||||
assert_correct_inference "$1" "she had reduce suit in greasy water all year" "$2"
|
||||
assert_correct_inference "$1" "she had the doctor in greasy wash for a year" "$2"
|
||||
}
|
||||
|
||||
assert_correct_warning_upsampling()
|
||||
|
@ -436,7 +436,7 @@ run_prod_concurrent_stream_tests()
|
|||
output2=$(echo "${output}" | tail -n 1)
|
||||
|
||||
assert_correct_ldc93s1_prodmodel "${output1}" "${status}"
|
||||
assert_correct_inference "${output2}" "i must find a new home in the stars" "${status}"
|
||||
assert_correct_inference "${output2}" "we must find a new home in the stars" "${status}"
|
||||
}
|
||||
|
||||
run_prod_inference_tests()
|
||||
|
|
|
@ -38,8 +38,8 @@ then:
|
|||
DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_arm64_build}/artifacts/public
|
||||
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
|
||||
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
|
||||
PIP_DEFAULT_TIMEOUT: "60"
|
||||
PIP_EXTRA_INDEX_URL: "https://lissyx.github.io/deepspeech-python-wheels/"
|
||||
EXTRA_PYTHON_CONFIGURE_OPTS: "" # Required by Debian Buster
|
||||
|
|
|
@ -43,8 +43,8 @@ then:
|
|||
DEEPSPEECH_ARTIFACTS_TFLITE_ROOT: https://queue.taskcluster.net/v1/task/${darwin_amd64_tflite}/artifacts/public
|
||||
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
|
||||
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
|
||||
EXPECTED_TENSORFLOW_VERSION: "${build.tensorflow_git_desc}"
|
||||
|
||||
command:
|
||||
|
|
|
@ -43,8 +43,8 @@ then:
|
|||
DEEPSPEECH_ARTIFACTS_TFLITE_ROOT: https://queue.taskcluster.net/v1/task/${linux_amd64_tflite}/artifacts/public
|
||||
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
|
||||
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
|
||||
DECODER_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_amd64_ctc}/artifacts/public
|
||||
PIP_DEFAULT_TIMEOUT: "60"
|
||||
EXPECTED_TENSORFLOW_VERSION: "${build.tensorflow_git_desc}"
|
||||
|
|
|
@ -38,8 +38,8 @@ then:
|
|||
DEEPSPEECH_ARTIFACTS_ROOT: https://queue.taskcluster.net/v1/task/${linux_rpi3_build}/artifacts/public
|
||||
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
|
||||
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
|
||||
PIP_DEFAULT_TIMEOUT: "60"
|
||||
PIP_EXTRA_INDEX_URL: "https://www.piwheels.org/simple"
|
||||
EXTRA_PYTHON_CONFIGURE_OPTS: "" # Required by Raspbian Buster / PiWheels
|
||||
|
|
|
@ -45,8 +45,8 @@ then:
|
|||
DEEPSPEECH_ARTIFACTS_TFLITE_ROOT: https://queue.taskcluster.net/v1/task/${win_amd64_tflite}/artifacts/public
|
||||
DEEPSPEECH_NODEJS: https://queue.taskcluster.net/v1/task/${node_package_cpu}/artifacts/public
|
||||
DEEPSPEECH_TEST_MODEL: https://queue.taskcluster.net/v1/task/${training}/artifacts/public/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.4/output_graph.pbmm
|
||||
DEEPSPEECH_PROD_MODEL: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pb
|
||||
DEEPSPEECH_PROD_MODEL_MMAP: https://github.com/lissyx/DeepSpeech/releases/download/prod-metadata-constant/output_graph.pbmm
|
||||
EXPECTED_TENSORFLOW_VERSION: "${build.tensorflow_git_desc}"
|
||||
TC_MSYS_VERSION: 'MSYS_NT-6.3'
|
||||
MSYS: 'winsymlinks:nativestrict'
|
||||
|
|
Загрузка…
Ссылка в новой задаче