Fix consumers of DS_CreateModel

Reuben Morais 2020-01-27 16:29:16 +01:00
Parent 8e9b6ef7b3
Commit c512383aec
16 changed files with 174 additions and 46 deletions
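For orientation, the consumer-facing shape of this change as a minimal Python sketch, based on the binding diffs below (the model path is illustrative):

    from deepspeech import Model

    # Before this commit: beam width was a required constructor argument.
    #   ds = Model('output_graph.pbmm', 500)

    # After this commit: the constructor takes only the model path; the beam
    # width defaults to the value stored in the model file and can be
    # overridden separately.
    ds = Model('output_graph.pbmm')   # illustrative path
    ds.setBeamWidth(500)              # optional override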

View file

@@ -877,7 +877,7 @@ def package_zip():
        json.dump({
            'name': FLAGS.export_language,
            'parameters': {
-                'beamWidth': FLAGS.beam_width,
+                'beamWidth': FLAGS.export_beam_width,
                'lmAlpha': FLAGS.lm_alpha,
                'lmBeta': FLAGS.lm_beta
            }

View file

@@ -36,7 +36,8 @@ LM_BETA = 1.85
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
-    ds = Model(model, BEAM_WIDTH)
+    ds = Model(model)
+    ds.setBeamWidth(BEAM_WIDTH)
    ds.enableExternalScorer(scorer)
    ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)

View file

@@ -16,7 +16,9 @@ char* scorer = NULL;
char* audio = NULL;
-int beam_width = 500;
+bool set_beamwidth = false;
+
+int beam_width = 0;
bool set_alphabeta = false;
@@ -98,6 +100,7 @@ bool ProcessArgs(int argc, char** argv)
      break;
    case 'b':
+      set_beamwidth = true;
      beam_width = atoi(optarg);
      break;

View file

@@ -368,14 +368,22 @@ main(int argc, char **argv)
  // Initialise DeepSpeech
  ModelState* ctx;
-  int status = DS_CreateModel(model, beam_width, &ctx);
+  int status = DS_CreateModel(model, &ctx);
  if (status != 0) {
    fprintf(stderr, "Could not create model.\n");
    return 1;
  }
+  if (set_beamwidth) {
+    status = DS_SetModelBeamWidth(ctx, beam_width);
+    if (status != 0) {
+      fprintf(stderr, "Could not set model beam width.\n");
+      return 1;
+    }
+  }
+
  if (scorer) {
-    int status = DS_EnableExternalScorer(ctx, scorer);
+    status = DS_EnableExternalScorer(ctx, scorer);
    if (status != 0) {
      fprintf(stderr, "Could not enable external scorer.\n");
      return 1;

View file

@@ -19,11 +19,10 @@ namespace DeepSpeechClient
        /// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
        /// </summary>
        /// <param name="aModelPath">The path to the frozen model graph.</param>
-        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
        /// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
-        public DeepSpeech(string aModelPath, uint aBeamWidth)
+        public DeepSpeech(string aModelPath)
        {
-            CreateModel(aModelPath, aBeamWidth);
+            CreateModel(aModelPath);
        }

        #region IDeepSpeech
@@ -32,10 +31,8 @@ namespace DeepSpeechClient
        /// Create an object providing an interface to a trained DeepSpeech model.
        /// </summary>
        /// <param name="aModelPath">The path to the frozen model graph.</param>
-        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
        /// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
-        private unsafe void CreateModel(string aModelPath,
-            uint aBeamWidth)
+        private unsafe void CreateModel(string aModelPath)
        {
            string exceptionMessage = null;
            if (string.IsNullOrWhiteSpace(aModelPath))
@@ -52,11 +49,31 @@ namespace DeepSpeechClient
                throw new FileNotFoundException(exceptionMessage);
            }
            var resultCode = NativeImp.DS_CreateModel(aModelPath,
-                            aBeamWidth,
                            ref _modelStatePP);
            EvaluateResultCode(resultCode);
        }

+        /// <summary>
+        /// Get beam width value used by the model. If SetModelBeamWidth was not
+        /// called before, will return the default value loaded from the model file.
+        /// </summary>
+        /// <returns>Beam width value used by the model.</returns>
+        public unsafe uint GetModelBeamWidth()
+        {
+            return NativeImp.DS_GetModelBeamWidth(_modelStatePP);
+        }
+
+        /// <summary>
+        /// Set beam width value used by the model.
+        /// </summary>
+        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
+        /// <exception cref="ArgumentException">Thrown on failure.</exception>
+        public unsafe void SetModelBeamWidth(uint aBeamWidth)
+        {
+            var resultCode = NativeImp.DS_SetModelBeamWidth(_modelStatePP, aBeamWidth);
+            EvaluateResultCode(resultCode);
+        }
+
        /// <summary>
        /// Return the sample rate expected by the model.
        /// </summary>

View file

@@ -20,6 +20,21 @@ namespace DeepSpeechClient.Interfaces
        /// <returns>Sample rate.</returns>
        unsafe int GetModelSampleRate();

+        /// <summary>
+        /// Get beam width value used by the model. If SetModelBeamWidth was not
+        /// called before, will return the default value loaded from the model
+        /// file.
+        /// </summary>
+        /// <returns>Beam width value used by the model.</returns>
+        unsafe uint GetModelBeamWidth();
+
+        /// <summary>
+        /// Set beam width value used by the model.
+        /// </summary>
+        /// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
+        /// <exception cref="ArgumentException">Thrown on failure.</exception>
+        unsafe void SetModelBeamWidth(uint aBeamWidth);
+
        /// <summary>
        /// Enable decoding using an external scorer.
        /// </summary>

View file

@@ -14,6 +14,17 @@ namespace DeepSpeechClient
        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
        internal static extern void DS_PrintVersions();

+        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
+        internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
+                   ref IntPtr** pint);
+
+        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
+        internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx);
+
+        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
+        internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx,
+                   uint aBeamWidth);
+
        [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
        internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
                   uint aBeamWidth,

View file

@@ -46,15 +46,12 @@ namespace CSharpExamples
                extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
            }

-            const uint BEAM_WIDTH = 500;
-
            Stopwatch stopwatch = new Stopwatch();

            try
            {
                Console.WriteLine("Loading model...");
                stopwatch.Start();
-                using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm",
-                    BEAM_WIDTH))
+                using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm"))
                {
                    stopwatch.Stop();

View file

@@ -49,7 +49,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
    private void newModel(String tfliteModel) {
        this._tfliteStatus.setText("Creating model");
        if (this._m == null) {
-            this._m = new DeepSpeechModel(tfliteModel, BEAM_WIDTH);
+            this._m = new DeepSpeechModel(tfliteModel);
+            this._m.setBeamWidth(BEAM_WIDTH);
        }
    }

View file

@@ -59,7 +59,7 @@ public class BasicTest {
    @Test
    public void loadDeepSpeech_basic() {
-        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+        DeepSpeechModel m = new DeepSpeechModel(modelFile);
        m.freeModel();
    }
@@ -116,7 +116,8 @@ public class BasicTest {
    @Test
    public void loadDeepSpeech_stt_noLM() {
-        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+        DeepSpeechModel m = new DeepSpeechModel(modelFile);
+        m.setBeamWidth(BEAM_WIDTH);

        String decoded = doSTT(m, false);
        assertEquals("she had your dark suit in greasy wash water all year", decoded);
@@ -125,7 +126,8 @@ public class BasicTest {
    @Test
    public void loadDeepSpeech_stt_withLM() {
-        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+        DeepSpeechModel m = new DeepSpeechModel(modelFile);
+        m.setBeamWidth(BEAM_WIDTH);
        m.enableExternalScorer(scorerFile);

        String decoded = doSTT(m, false);
@@ -135,7 +137,8 @@ public class BasicTest {
    @Test
    public void loadDeepSpeech_sttWithMetadata_noLM() {
-        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+        DeepSpeechModel m = new DeepSpeechModel(modelFile);
+        m.setBeamWidth(BEAM_WIDTH);

        String decoded = doSTT(m, true);
        assertEquals("she had your dark suit in greasy wash water all year", decoded);
@@ -144,7 +147,8 @@ public class BasicTest {
    @Test
    public void loadDeepSpeech_sttWithMetadata_withLM() {
-        DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
+        DeepSpeechModel m = new DeepSpeechModel(modelFile);
+        m.setBeamWidth(BEAM_WIDTH);
        m.enableExternalScorer(scorerFile);

        String decoded = doSTT(m, true);

View file

@@ -20,16 +20,35 @@ public class DeepSpeechModel {
    * @constructor
    *
    * @param modelPath The path to the frozen model graph.
-   * @param beam_width The beam width used by the decoder. A larger beam
-   *                   width generates better results at the cost of decoding
-   *                   time.
    */
-   public DeepSpeechModel(String modelPath, int beam_width) {
+   public DeepSpeechModel(String modelPath) {
        this._mspp = impl.new_modelstatep();
-        impl.CreateModel(modelPath, beam_width, this._mspp);
+        impl.CreateModel(modelPath, this._mspp);
        this._msp = impl.modelstatep_value(this._mspp);
    }

+   /**
+    * @brief Get beam width value used by the model. If setModelBeamWidth was not
+    *        called before, will return the default value loaded from the model file.
+    *
+    * @return Beam width value used by the model.
+    */
+   public int beamWidth() {
+       return impl.GetModelBeamWidth(this._msp);
+   }
+
+   /**
+    * @brief Set beam width value used by the model.
+    *
+    * @param aBeamWidth The beam width used by the model. A larger beam width value
+    *                   generates better results at the cost of decoding time.
+    *
+    * @return Zero on success, non-zero on failure.
+    */
+   public int setBeamWidth(int beamWidth) {
+       return impl.SetModelBeamWidth(this._msp, beamWidth);
+   }
+
   /**
    * @brief Return the sample rate expected by the model.
    *

View file

@@ -31,7 +31,7 @@ var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running D
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
-parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
+parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'});
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
@@ -53,10 +53,14 @@ function metadataToString(metadata) {
console.error('Loading model from file %s', args['model']);
const model_load_start = process.hrtime();
-var model = new Ds.Model(args['model'], args['beam_width']);
+var model = new Ds.Model(args['model']);
const model_load_end = process.hrtime(model_load_start);
console.error('Loaded model in %ds.', totalTime(model_load_end));

+if (args['beam_width']) {
+    model.setBeamWidth(args['beam_width']);
+}
+
var desired_sample_rate = model.sampleRate();

if (args['scorer']) {

View file

@@ -25,14 +25,13 @@ if (process.platform === 'win32') {
 * An object providing an interface to a trained DeepSpeech model.
 *
 * @param {string} aModelPath The path to the frozen model graph.
- * @param {number} aBeamWidth The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.
 *
 * @throws on error
 */
-function Model() {
+function Model(aModelPath) {
    this._impl = null;

-    const rets = binding.CreateModel.apply(null, arguments);
+    const rets = binding.CreateModel(aModelPath);
    const status = rets[0];
    const impl = rets[1];
    if (status !== 0) {
@@ -42,6 +41,38 @@ function Model() {
    this._impl = impl;
}

+/**
+ * Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
+ * was not called before, will return the default value loaded from the
+ * model file.
+ *
+ * @return {number} Beam width value used by the model.
+ */
+Model.prototype.beamWidth = function() {
+    return binding.GetModelBeamWidth(this._impl);
+}
+
+/**
+ * Set beam width value used by the model.
+ *
+ * @param {number} The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
+ *
+ * @return {number} Zero on success, non-zero on failure.
+ */
+Model.prototype.setBeamWidth = function(aBeamWidth) {
+    return binding.SetModelBeamWidth(this._impl, aBeamWidth);
+}
+
+/**
+ * Return the sample rate expected by the model.
+ *
+ * @return {number} Sample rate.
+ */
+Model.prototype.beamWidth = function() {
+    return binding.GetModelBeamWidth(this._impl);
+}
+
/**
 * Return the sample rate expected by the model.
 *

View file

@@ -28,15 +28,12 @@ class Model(object):
    :param aModelPath: Path to model file to load
    :type aModelPath: str
-
-    :param aBeamWidth: Decoder beam width
-    :type aBeamWidth: int
    """
-    def __init__(self, *args, **kwargs):
+    def __init__(self, model_path):
        # make sure the attribute is there if CreateModel fails
        self._impl = None

-        status, impl = deepspeech.impl.CreateModel(*args, **kwargs)
+        status, impl = deepspeech.impl.CreateModel(model_path)
        if status != 0:
            raise RuntimeError("CreateModel failed with error code {}".format(status))
        self._impl = impl
@@ -46,6 +43,29 @@ class Model(object):
            deepspeech.impl.FreeModel(self._impl)
            self._impl = None

+    def beamWidth(self):
+        """
+        Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
+        was not called before, will return the default value loaded from the
+        model file.
+
+        :return: Beam width value used by the model.
+        :type: int
+        """
+        return deepspeech.impl.GetModelBeamWidth(self._impl)
+
+    def setBeamWidth(self, beam_width):
+        """
+        Set beam width value used by the model.
+
+        :param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
+        :type beam_width: int
+
+        :return: Zero on success, non-zero on failure.
+        :type: int
+        """
+        return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width)
+
    def sampleRate(self):
        """
        Return the sample rate expected by the model.
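A minimal round trip through the two methods added to the Python binding above (the model path is illustrative and assumes a valid model file on disk):

    from deepspeech import Model

    ds = Model('output_graph.pbmm')   # illustrative path
    print(ds.beamWidth())             # default read from the model file
    status = ds.setBeamWidth(1024)    # returns zero on success
    print(status, ds.beamWidth())     # per the docstring, now reflects the value just set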

View file

@@ -92,7 +92,7 @@ def main():
                        help='Path to the external scorer file')
    parser.add_argument('--audio', required=True,
                        help='Path to the audio file to run (WAV format)')
-    parser.add_argument('--beam_width', type=int, default=500,
+    parser.add_argument('--beam_width', type=int,
                        help='Beam width for the CTC decoder')
    parser.add_argument('--lm_alpha', type=float,
                        help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
@@ -108,10 +108,13 @@ def main():
    print('Loading model from file {}'.format(args.model), file=sys.stderr)
    model_load_start = timer()
-    ds = Model(args.model, args.beam_width)
+    ds = Model(args.model)
    model_load_end = timer() - model_load_start
    print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)

+    if args.beam_width:
+        ds.setModelBeamWidth(args.beam_width)
+
    desired_sample_rate = ds.sampleRate()

    if args.scorer:
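For comparison, a minimal sketch of handling the now-optional --beam_width flag through the setter the Python binding above defines (setBeamWidth); the argument setup mirrors the client code and is otherwise illustrative:

    import argparse
    from deepspeech import Model

    parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
    parser.add_argument('--model', required=True,
                        help='Path to the model (protocol buffer binary file)')
    parser.add_argument('--beam_width', type=int,
                        help='Beam width for the CTC decoder')
    args = parser.parse_args()

    ds = Model(args.model)
    if args.beam_width:
        ds.setBeamWidth(args.beam_width)   # omit to keep the model-file default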

View file

@@ -9,12 +9,6 @@ import wave
from deepspeech import Model

-# These constants control the beam search decoder
-
-# Beam width used in the CTC decoder when building candidate transcriptions
-BEAM_WIDTH = 500
-
def main():
    parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
    parser.add_argument('--model', required=True,
@@ -27,7 +21,7 @@ def main():
                        help='Second audio file to use in interleaved streams')
    args = parser.parse_args()

-    ds = Model(args.model, BEAM_WIDTH)
+    ds = Model(args.model)

    if args.scorer:
        ds.enableExternalScorer(args.scorer)