From c512383aecce02945cd09b6aa83cd715e335e583 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 27 Jan 2020 16:29:16 +0100 Subject: [PATCH] Fix consumers of DS_CreateModel --- DeepSpeech.py | 2 +- evaluate_tflite.py | 3 +- native_client/args.h | 5 ++- native_client/client.cc | 12 +++++- .../dotnet/DeepSpeechClient/DeepSpeech.cs | 31 ++++++++++++---- .../Interfaces/IDeepSpeech.cs | 15 ++++++++ .../dotnet/DeepSpeechClient/NativeImp.cs | 11 ++++++ .../dotnet/DeepSpeechConsole/Program.cs | 5 +-- .../deepspeech/DeepSpeechActivity.java | 3 +- .../libdeepspeech/test/BasicTest.java | 14 ++++--- .../libdeepspeech/DeepSpeechModel.java | 29 ++++++++++++--- native_client/javascript/client.js | 8 +++- native_client/javascript/index.js | 37 +++++++++++++++++-- native_client/python/__init__.py | 30 ++++++++++++--- native_client/python/client.py | 7 +++- native_client/test/concurrent_streams.py | 8 +--- 16 files changed, 174 insertions(+), 46 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index 48b8edb6..9421e7f0 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -877,7 +877,7 @@ def package_zip(): json.dump({ 'name': FLAGS.export_language, 'parameters': { - 'beamWidth': FLAGS.beam_width, + 'beamWidth': FLAGS.export_beam_width, 'lmAlpha': FLAGS.lm_alpha, 'lmBeta': FLAGS.lm_beta } diff --git a/evaluate_tflite.py b/evaluate_tflite.py index aba6fb68..d01db864 100644 --- a/evaluate_tflite.py +++ b/evaluate_tflite.py @@ -36,7 +36,8 @@ LM_BETA = 1.85 def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask): os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask) - ds = Model(model, BEAM_WIDTH) + ds = Model(model) + ds.setBeamWidth(BEAM_WIDTH) ds.enableExternalScorer(scorer) ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA) diff --git a/native_client/args.h b/native_client/args.h index d5a0f869..2e7306c7 100644 --- a/native_client/args.h +++ b/native_client/args.h @@ -16,7 +16,9 @@ char* scorer = NULL; char* audio = NULL; -int beam_width = 500; +bool set_beamwidth = false; + +int beam_width = 0; bool set_alphabeta = false; @@ -98,6 +100,7 @@ bool ProcessArgs(int argc, char** argv) break; case 'b': + set_beamwidth = true; beam_width = atoi(optarg); break; diff --git a/native_client/client.cc b/native_client/client.cc index 718fba75..abcadd8d 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -368,14 +368,22 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; - int status = DS_CreateModel(model, beam_width, &ctx); + int status = DS_CreateModel(model, &ctx); if (status != 0) { fprintf(stderr, "Could not create model.\n"); return 1; } + if (set_beamwidth) { + status = DS_SetModelBeamWidth(ctx, beam_width); + if (status != 0) { + fprintf(stderr, "Could not set model beam width.\n"); + return 1; + } + } + if (scorer) { - int status = DS_EnableExternalScorer(ctx, scorer); + status = DS_EnableExternalScorer(ctx, scorer); if (status != 0) { fprintf(stderr, "Could not enable external scorer.\n"); return 1; diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs index e5e33370..15c2212c 100644 --- a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs @@ -19,11 +19,10 @@ namespace DeepSpeechClient /// Initializes a new instance of class and creates a new acoustic model. /// /// The path to the frozen model graph. - /// The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time. /// Thrown when the native binary failed to create the model. - public DeepSpeech(string aModelPath, uint aBeamWidth) + public DeepSpeech(string aModelPath) { - CreateModel(aModelPath, aBeamWidth); + CreateModel(aModelPath); } #region IDeepSpeech @@ -32,10 +31,8 @@ namespace DeepSpeechClient /// Create an object providing an interface to a trained DeepSpeech model. /// /// The path to the frozen model graph. - /// The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time. /// Thrown when the native binary failed to create the model. - private unsafe void CreateModel(string aModelPath, - uint aBeamWidth) + private unsafe void CreateModel(string aModelPath) { string exceptionMessage = null; if (string.IsNullOrWhiteSpace(aModelPath)) @@ -52,11 +49,31 @@ namespace DeepSpeechClient throw new FileNotFoundException(exceptionMessage); } var resultCode = NativeImp.DS_CreateModel(aModelPath, - aBeamWidth, ref _modelStatePP); EvaluateResultCode(resultCode); } + /// + /// Get beam width value used by the model. If SetModelBeamWidth was not + /// called before, will return the default value loaded from the model file. + /// + /// Beam width value used by the model. + public unsafe uint GetModelBeamWidth() + { + return NativeImp.DS_GetModelBeamWidth(_modelStatePP); + } + + /// + /// Set beam width value used by the model. + /// + /// The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time. + /// Thrown on failure. + public unsafe void SetModelBeamWidth(uint aBeamWidth) + { + var resultCode = NativeImp.DS_SetModelBeamWidth(_modelStatePP, aBeamWidth); + EvaluateResultCode(resultCode); + } + /// /// Return the sample rate expected by the model. /// diff --git a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs index ecbfb7e9..f00c188d 100644 --- a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs @@ -20,6 +20,21 @@ namespace DeepSpeechClient.Interfaces /// Sample rate. unsafe int GetModelSampleRate(); + /// + /// Get beam width value used by the model. If SetModelBeamWidth was not + /// called before, will return the default value loaded from the model + /// file. + /// + /// Beam width value used by the model. + unsafe uint GetModelBeamWidth(); + + /// + /// Set beam width value used by the model. + /// + /// The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time. + /// Thrown on failure. + unsafe void SetModelBeamWidth(uint aBeamWidth); + /// /// Enable decoding using an external scorer. /// diff --git a/native_client/dotnet/DeepSpeechClient/NativeImp.cs b/native_client/dotnet/DeepSpeechClient/NativeImp.cs index 1c49feec..af28618c 100644 --- a/native_client/dotnet/DeepSpeechClient/NativeImp.cs +++ b/native_client/dotnet/DeepSpeechClient/NativeImp.cs @@ -14,6 +14,17 @@ namespace DeepSpeechClient [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static extern void DS_PrintVersions(); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, + ref IntPtr** pint); + + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx); + + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx, + uint aBeamWidth); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, uint aBeamWidth, diff --git a/native_client/dotnet/DeepSpeechConsole/Program.cs b/native_client/dotnet/DeepSpeechConsole/Program.cs index 1f6e299b..b35c7046 100644 --- a/native_client/dotnet/DeepSpeechConsole/Program.cs +++ b/native_client/dotnet/DeepSpeechConsole/Program.cs @@ -46,15 +46,12 @@ namespace CSharpExamples extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended")); } - const uint BEAM_WIDTH = 500; - Stopwatch stopwatch = new Stopwatch(); try { Console.WriteLine("Loading model..."); stopwatch.Start(); - using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm", - BEAM_WIDTH)) + using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm")) { stopwatch.Stop(); diff --git a/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java b/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java index 12e758df..f9d9a11e 100644 --- a/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java +++ b/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java @@ -49,7 +49,8 @@ public class DeepSpeechActivity extends AppCompatActivity { private void newModel(String tfliteModel) { this._tfliteStatus.setText("Creating model"); if (this._m == null) { - this._m = new DeepSpeechModel(tfliteModel, BEAM_WIDTH); + this._m = new DeepSpeechModel(tfliteModel); + this._m.setBeamWidth(BEAM_WIDTH); } } diff --git a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java index bb6bbe42..60d21256 100644 --- a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java +++ b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java @@ -59,7 +59,7 @@ public class BasicTest { @Test public void loadDeepSpeech_basic() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); m.freeModel(); } @@ -116,7 +116,8 @@ public class BasicTest { @Test public void loadDeepSpeech_stt_noLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); + m.setBeamWidth(BEAM_WIDTH); String decoded = doSTT(m, false); assertEquals("she had your dark suit in greasy wash water all year", decoded); @@ -125,7 +126,8 @@ public class BasicTest { @Test public void loadDeepSpeech_stt_withLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); + m.setBeamWidth(BEAM_WIDTH); m.enableExternalScorer(scorerFile); String decoded = doSTT(m, false); @@ -135,7 +137,8 @@ public class BasicTest { @Test public void loadDeepSpeech_sttWithMetadata_noLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); + m.setBeamWidth(BEAM_WIDTH); String decoded = doSTT(m, true); assertEquals("she had your dark suit in greasy wash water all year", decoded); @@ -144,7 +147,8 @@ public class BasicTest { @Test public void loadDeepSpeech_sttWithMetadata_withLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); + m.setBeamWidth(BEAM_WIDTH); m.enableExternalScorer(scorerFile); String decoded = doSTT(m, true); diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java index 0438ac10..1c26e2f9 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java @@ -20,16 +20,35 @@ public class DeepSpeechModel { * @constructor * * @param modelPath The path to the frozen model graph. - * @param beam_width The beam width used by the decoder. A larger beam - * width generates better results at the cost of decoding - * time. */ - public DeepSpeechModel(String modelPath, int beam_width) { + public DeepSpeechModel(String modelPath) { this._mspp = impl.new_modelstatep(); - impl.CreateModel(modelPath, beam_width, this._mspp); + impl.CreateModel(modelPath, this._mspp); this._msp = impl.modelstatep_value(this._mspp); } + /** + * @brief Get beam width value used by the model. If setModelBeamWidth was not + * called before, will return the default value loaded from the model file. + * + * @return Beam width value used by the model. + */ + public int beamWidth() { + return impl.GetModelBeamWidth(this._msp); + } + + /** + * @brief Set beam width value used by the model. + * + * @param aBeamWidth The beam width used by the model. A larger beam width value + * generates better results at the cost of decoding time. + * + * @return Zero on success, non-zero on failure. + */ + public int setBeamWidth(int beamWidth) { + return impl.SetModelBeamWidth(this._msp, beamWidth); + } + /** * @brief Return the sample rate expected by the model. * diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js index 7266b85d..09406ccc 100644 --- a/native_client/javascript/client.js +++ b/native_client/javascript/client.js @@ -31,7 +31,7 @@ var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running D parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'}); parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'}); parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'}); -parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'}); +parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'}); parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'}); parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'}); parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'}); @@ -53,10 +53,14 @@ function metadataToString(metadata) { console.error('Loading model from file %s', args['model']); const model_load_start = process.hrtime(); -var model = new Ds.Model(args['model'], args['beam_width']); +var model = new Ds.Model(args['model']); const model_load_end = process.hrtime(model_load_start); console.error('Loaded model in %ds.', totalTime(model_load_end)); +if (args['beam_width']) { + model.setBeamWidth(args['beam_width']); +} + var desired_sample_rate = model.sampleRate(); if (args['scorer']) { diff --git a/native_client/javascript/index.js b/native_client/javascript/index.js index 772b1a82..38ecbf0a 100644 --- a/native_client/javascript/index.js +++ b/native_client/javascript/index.js @@ -25,14 +25,13 @@ if (process.platform === 'win32') { * An object providing an interface to a trained DeepSpeech model. * * @param {string} aModelPath The path to the frozen model graph. - * @param {number} aBeamWidth The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time. * * @throws on error */ -function Model() { +function Model(aModelPath) { this._impl = null; - const rets = binding.CreateModel.apply(null, arguments); + const rets = binding.CreateModel(aModelPath); const status = rets[0]; const impl = rets[1]; if (status !== 0) { @@ -42,6 +41,38 @@ function Model() { this._impl = impl; } +/** + * Get beam width value used by the model. If {@link DS_SetModelBeamWidth} + * was not called before, will return the default value loaded from the + * model file. + * + * @return {number} Beam width value used by the model. + */ +Model.prototype.beamWidth = function() { + return binding.GetModelBeamWidth(this._impl); +} + +/** + * Set beam width value used by the model. + * + * @param {number} The beam width used by the model. A larger beam width value generates better results at the cost of decoding time. + * + * @return {number} Zero on success, non-zero on failure. + */ +Model.prototype.setBeamWidth = function(aBeamWidth) { + return binding.SetModelBeamWidth(this._impl, aBeamWidth); +} + +/** + * Return the sample rate expected by the model. + * + * @return {number} Sample rate. + */ +Model.prototype.beamWidth = function() { + return binding.GetModelBeamWidth(this._impl); +} + + /** * Return the sample rate expected by the model. * diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index ccb53fc4..855a6eeb 100644 --- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -28,15 +28,12 @@ class Model(object): :param aModelPath: Path to model file to load :type aModelPath: str - - :param aBeamWidth: Decoder beam width - :type aBeamWidth: int """ - def __init__(self, *args, **kwargs): + def __init__(self, model_path): # make sure the attribute is there if CreateModel fails self._impl = None - status, impl = deepspeech.impl.CreateModel(*args, **kwargs) + status, impl = deepspeech.impl.CreateModel(model_path) if status != 0: raise RuntimeError("CreateModel failed with error code {}".format(status)) self._impl = impl @@ -46,6 +43,29 @@ class Model(object): deepspeech.impl.FreeModel(self._impl) self._impl = None + def beamWidth(self): + """ + Get beam width value used by the model. If {@link DS_SetModelBeamWidth} + was not called before, will return the default value loaded from the + model file. + + :return: Beam width value used by the model. + :type: int + """ + return deepspeech.impl.GetModelBeamWidth(self._impl) + + def setBeamWidth(self, beam_width): + """ + Set beam width value used by the model. + + :param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time. + :type beam_width: int + + :return: Zero on success, non-zero on failure. + :type: int + """ + return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width) + def sampleRate(self): """ Return the sample rate expected by the model. diff --git a/native_client/python/client.py b/native_client/python/client.py index 2ef88caf..26db1e00 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -92,7 +92,7 @@ def main(): help='Path to the external scorer file') parser.add_argument('--audio', required=True, help='Path to the audio file to run (WAV format)') - parser.add_argument('--beam_width', type=int, default=500, + parser.add_argument('--beam_width', type=int, help='Beam width for the CTC decoder') parser.add_argument('--lm_alpha', type=float, help='Language model weight (lm_alpha). If not specified, use default from the scorer package.') @@ -108,10 +108,13 @@ def main(): print('Loading model from file {}'.format(args.model), file=sys.stderr) model_load_start = timer() - ds = Model(args.model, args.beam_width) + ds = Model(args.model) model_load_end = timer() - model_load_start print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr) + if args.beam_width: + ds.setModelBeamWidth(args.beam_width) + desired_sample_rate = ds.sampleRate() if args.scorer: diff --git a/native_client/test/concurrent_streams.py b/native_client/test/concurrent_streams.py index d799de36..e435b43f 100644 --- a/native_client/test/concurrent_streams.py +++ b/native_client/test/concurrent_streams.py @@ -9,12 +9,6 @@ import wave from deepspeech import Model -# These constants control the beam search decoder - -# Beam width used in the CTC decoder when building candidate transcriptions -BEAM_WIDTH = 500 - - def main(): parser = argparse.ArgumentParser(description='Running DeepSpeech inference.') parser.add_argument('--model', required=True, @@ -27,7 +21,7 @@ def main(): help='Second audio file to use in interleaved streams') args = parser.parse_args() - ds = Model(args.model, BEAM_WIDTH) + ds = Model(args.model) if args.scorer: ds.enableExternalScorer(args.scorer)