зеркало из https://github.com/mozilla/DeepSpeech.git
Fix consumers of DS_CreateModel
This commit is contained in:
Родитель
8e9b6ef7b3
Коммит
c512383aec
|
@ -877,7 +877,7 @@ def package_zip():
|
||||||
json.dump({
|
json.dump({
|
||||||
'name': FLAGS.export_language,
|
'name': FLAGS.export_language,
|
||||||
'parameters': {
|
'parameters': {
|
||||||
'beamWidth': FLAGS.beam_width,
|
'beamWidth': FLAGS.export_beam_width,
|
||||||
'lmAlpha': FLAGS.lm_alpha,
|
'lmAlpha': FLAGS.lm_alpha,
|
||||||
'lmBeta': FLAGS.lm_beta
|
'lmBeta': FLAGS.lm_beta
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,8 @@ LM_BETA = 1.85
|
||||||
|
|
||||||
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
|
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
|
||||||
ds = Model(model, BEAM_WIDTH)
|
ds = Model(model)
|
||||||
|
ds.setBeamWidth(BEAM_WIDTH)
|
||||||
ds.enableExternalScorer(scorer)
|
ds.enableExternalScorer(scorer)
|
||||||
ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
|
ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,9 @@ char* scorer = NULL;
|
||||||
|
|
||||||
char* audio = NULL;
|
char* audio = NULL;
|
||||||
|
|
||||||
int beam_width = 500;
|
bool set_beamwidth = false;
|
||||||
|
|
||||||
|
int beam_width = 0;
|
||||||
|
|
||||||
bool set_alphabeta = false;
|
bool set_alphabeta = false;
|
||||||
|
|
||||||
|
@ -98,6 +100,7 @@ bool ProcessArgs(int argc, char** argv)
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case 'b':
|
case 'b':
|
||||||
|
set_beamwidth = true;
|
||||||
beam_width = atoi(optarg);
|
beam_width = atoi(optarg);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|
|
@ -368,14 +368,22 @@ main(int argc, char **argv)
|
||||||
|
|
||||||
// Initialise DeepSpeech
|
// Initialise DeepSpeech
|
||||||
ModelState* ctx;
|
ModelState* ctx;
|
||||||
int status = DS_CreateModel(model, beam_width, &ctx);
|
int status = DS_CreateModel(model, &ctx);
|
||||||
if (status != 0) {
|
if (status != 0) {
|
||||||
fprintf(stderr, "Could not create model.\n");
|
fprintf(stderr, "Could not create model.\n");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (set_beamwidth) {
|
||||||
|
status = DS_SetModelBeamWidth(ctx, beam_width);
|
||||||
|
if (status != 0) {
|
||||||
|
fprintf(stderr, "Could not set model beam width.\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (scorer) {
|
if (scorer) {
|
||||||
int status = DS_EnableExternalScorer(ctx, scorer);
|
status = DS_EnableExternalScorer(ctx, scorer);
|
||||||
if (status != 0) {
|
if (status != 0) {
|
||||||
fprintf(stderr, "Could not enable external scorer.\n");
|
fprintf(stderr, "Could not enable external scorer.\n");
|
||||||
return 1;
|
return 1;
|
||||||
|
|
|
@ -19,11 +19,10 @@ namespace DeepSpeechClient
|
||||||
/// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
|
/// Initializes a new instance of <see cref="DeepSpeech"/> class and creates a new acoustic model.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="aModelPath">The path to the frozen model graph.</param>
|
/// <param name="aModelPath">The path to the frozen model graph.</param>
|
||||||
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
|
|
||||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
|
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
|
||||||
public DeepSpeech(string aModelPath, uint aBeamWidth)
|
public DeepSpeech(string aModelPath)
|
||||||
{
|
{
|
||||||
CreateModel(aModelPath, aBeamWidth);
|
CreateModel(aModelPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
#region IDeepSpeech
|
#region IDeepSpeech
|
||||||
|
@ -32,10 +31,8 @@ namespace DeepSpeechClient
|
||||||
/// Create an object providing an interface to a trained DeepSpeech model.
|
/// Create an object providing an interface to a trained DeepSpeech model.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="aModelPath">The path to the frozen model graph.</param>
|
/// <param name="aModelPath">The path to the frozen model graph.</param>
|
||||||
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.</param>
|
|
||||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
|
/// <exception cref="ArgumentException">Thrown when the native binary failed to create the model.</exception>
|
||||||
private unsafe void CreateModel(string aModelPath,
|
private unsafe void CreateModel(string aModelPath)
|
||||||
uint aBeamWidth)
|
|
||||||
{
|
{
|
||||||
string exceptionMessage = null;
|
string exceptionMessage = null;
|
||||||
if (string.IsNullOrWhiteSpace(aModelPath))
|
if (string.IsNullOrWhiteSpace(aModelPath))
|
||||||
|
@ -52,11 +49,31 @@ namespace DeepSpeechClient
|
||||||
throw new FileNotFoundException(exceptionMessage);
|
throw new FileNotFoundException(exceptionMessage);
|
||||||
}
|
}
|
||||||
var resultCode = NativeImp.DS_CreateModel(aModelPath,
|
var resultCode = NativeImp.DS_CreateModel(aModelPath,
|
||||||
aBeamWidth,
|
|
||||||
ref _modelStatePP);
|
ref _modelStatePP);
|
||||||
EvaluateResultCode(resultCode);
|
EvaluateResultCode(resultCode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Get beam width value used by the model. If SetModelBeamWidth was not
|
||||||
|
/// called before, will return the default value loaded from the model file.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Beam width value used by the model.</returns>
|
||||||
|
public unsafe uint GetModelBeamWidth()
|
||||||
|
{
|
||||||
|
return NativeImp.DS_GetModelBeamWidth(_modelStatePP);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Set beam width value used by the model.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
|
||||||
|
/// <exception cref="ArgumentException">Thrown on failure.</exception>
|
||||||
|
public unsafe void SetModelBeamWidth(uint aBeamWidth)
|
||||||
|
{
|
||||||
|
var resultCode = NativeImp.DS_SetModelBeamWidth(_modelStatePP, aBeamWidth);
|
||||||
|
EvaluateResultCode(resultCode);
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Return the sample rate expected by the model.
|
/// Return the sample rate expected by the model.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
|
@ -20,6 +20,21 @@ namespace DeepSpeechClient.Interfaces
|
||||||
/// <returns>Sample rate.</returns>
|
/// <returns>Sample rate.</returns>
|
||||||
unsafe int GetModelSampleRate();
|
unsafe int GetModelSampleRate();
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Get beam width value used by the model. If SetModelBeamWidth was not
|
||||||
|
/// called before, will return the default value loaded from the model
|
||||||
|
/// file.
|
||||||
|
/// </summary>
|
||||||
|
/// <returns>Beam width value used by the model.</returns>
|
||||||
|
unsafe uint GetModelBeamWidth();
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Set beam width value used by the model.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="aBeamWidth">The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time.</param>
|
||||||
|
/// <exception cref="ArgumentException">Thrown on failure.</exception>
|
||||||
|
unsafe void SetModelBeamWidth(uint aBeamWidth);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Enable decoding using an external scorer.
|
/// Enable decoding using an external scorer.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|
|
@ -14,6 +14,17 @@ namespace DeepSpeechClient
|
||||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||||
internal static extern void DS_PrintVersions();
|
internal static extern void DS_PrintVersions();
|
||||||
|
|
||||||
|
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
|
||||||
|
ref IntPtr** pint);
|
||||||
|
|
||||||
|
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx);
|
||||||
|
|
||||||
|
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||||
|
internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx,
|
||||||
|
uint aBeamWidth);
|
||||||
|
|
||||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||||
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
|
internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath,
|
||||||
uint aBeamWidth,
|
uint aBeamWidth,
|
||||||
|
|
|
@ -46,15 +46,12 @@ namespace CSharpExamples
|
||||||
extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
|
extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint BEAM_WIDTH = 500;
|
|
||||||
|
|
||||||
Stopwatch stopwatch = new Stopwatch();
|
Stopwatch stopwatch = new Stopwatch();
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
Console.WriteLine("Loading model...");
|
Console.WriteLine("Loading model...");
|
||||||
stopwatch.Start();
|
stopwatch.Start();
|
||||||
using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm",
|
using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm"))
|
||||||
BEAM_WIDTH))
|
|
||||||
{
|
{
|
||||||
stopwatch.Stop();
|
stopwatch.Stop();
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
|
||||||
private void newModel(String tfliteModel) {
|
private void newModel(String tfliteModel) {
|
||||||
this._tfliteStatus.setText("Creating model");
|
this._tfliteStatus.setText("Creating model");
|
||||||
if (this._m == null) {
|
if (this._m == null) {
|
||||||
this._m = new DeepSpeechModel(tfliteModel, BEAM_WIDTH);
|
this._m = new DeepSpeechModel(tfliteModel);
|
||||||
|
this._m.setBeamWidth(BEAM_WIDTH);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,7 @@ public class BasicTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadDeepSpeech_basic() {
|
public void loadDeepSpeech_basic() {
|
||||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||||
m.freeModel();
|
m.freeModel();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +116,8 @@ public class BasicTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadDeepSpeech_stt_noLM() {
|
public void loadDeepSpeech_stt_noLM() {
|
||||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||||
|
m.setBeamWidth(BEAM_WIDTH);
|
||||||
|
|
||||||
String decoded = doSTT(m, false);
|
String decoded = doSTT(m, false);
|
||||||
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
||||||
|
@ -125,7 +126,8 @@ public class BasicTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadDeepSpeech_stt_withLM() {
|
public void loadDeepSpeech_stt_withLM() {
|
||||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||||
|
m.setBeamWidth(BEAM_WIDTH);
|
||||||
m.enableExternalScorer(scorerFile);
|
m.enableExternalScorer(scorerFile);
|
||||||
|
|
||||||
String decoded = doSTT(m, false);
|
String decoded = doSTT(m, false);
|
||||||
|
@ -135,7 +137,8 @@ public class BasicTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadDeepSpeech_sttWithMetadata_noLM() {
|
public void loadDeepSpeech_sttWithMetadata_noLM() {
|
||||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||||
|
m.setBeamWidth(BEAM_WIDTH);
|
||||||
|
|
||||||
String decoded = doSTT(m, true);
|
String decoded = doSTT(m, true);
|
||||||
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
||||||
|
@ -144,7 +147,8 @@ public class BasicTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void loadDeepSpeech_sttWithMetadata_withLM() {
|
public void loadDeepSpeech_sttWithMetadata_withLM() {
|
||||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
DeepSpeechModel m = new DeepSpeechModel(modelFile);
|
||||||
|
m.setBeamWidth(BEAM_WIDTH);
|
||||||
m.enableExternalScorer(scorerFile);
|
m.enableExternalScorer(scorerFile);
|
||||||
|
|
||||||
String decoded = doSTT(m, true);
|
String decoded = doSTT(m, true);
|
||||||
|
|
|
@ -20,16 +20,35 @@ public class DeepSpeechModel {
|
||||||
* @constructor
|
* @constructor
|
||||||
*
|
*
|
||||||
* @param modelPath The path to the frozen model graph.
|
* @param modelPath The path to the frozen model graph.
|
||||||
* @param beam_width The beam width used by the decoder. A larger beam
|
|
||||||
* width generates better results at the cost of decoding
|
|
||||||
* time.
|
|
||||||
*/
|
*/
|
||||||
public DeepSpeechModel(String modelPath, int beam_width) {
|
public DeepSpeechModel(String modelPath) {
|
||||||
this._mspp = impl.new_modelstatep();
|
this._mspp = impl.new_modelstatep();
|
||||||
impl.CreateModel(modelPath, beam_width, this._mspp);
|
impl.CreateModel(modelPath, this._mspp);
|
||||||
this._msp = impl.modelstatep_value(this._mspp);
|
this._msp = impl.modelstatep_value(this._mspp);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get beam width value used by the model. If setModelBeamWidth was not
|
||||||
|
* called before, will return the default value loaded from the model file.
|
||||||
|
*
|
||||||
|
* @return Beam width value used by the model.
|
||||||
|
*/
|
||||||
|
public int beamWidth() {
|
||||||
|
return impl.GetModelBeamWidth(this._msp);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set beam width value used by the model.
|
||||||
|
*
|
||||||
|
* @param aBeamWidth The beam width used by the model. A larger beam width value
|
||||||
|
* generates better results at the cost of decoding time.
|
||||||
|
*
|
||||||
|
* @return Zero on success, non-zero on failure.
|
||||||
|
*/
|
||||||
|
public int setBeamWidth(int beamWidth) {
|
||||||
|
return impl.SetModelBeamWidth(this._msp, beamWidth);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return the sample rate expected by the model.
|
* @brief Return the sample rate expected by the model.
|
||||||
*
|
*
|
||||||
|
|
|
@ -31,7 +31,7 @@ var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running D
|
||||||
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
|
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
|
||||||
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
|
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
|
||||||
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
|
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
|
||||||
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
|
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'});
|
||||||
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
|
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
|
||||||
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
|
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
|
||||||
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
|
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
|
||||||
|
@ -53,10 +53,14 @@ function metadataToString(metadata) {
|
||||||
|
|
||||||
console.error('Loading model from file %s', args['model']);
|
console.error('Loading model from file %s', args['model']);
|
||||||
const model_load_start = process.hrtime();
|
const model_load_start = process.hrtime();
|
||||||
var model = new Ds.Model(args['model'], args['beam_width']);
|
var model = new Ds.Model(args['model']);
|
||||||
const model_load_end = process.hrtime(model_load_start);
|
const model_load_end = process.hrtime(model_load_start);
|
||||||
console.error('Loaded model in %ds.', totalTime(model_load_end));
|
console.error('Loaded model in %ds.', totalTime(model_load_end));
|
||||||
|
|
||||||
|
if (args['beam_width']) {
|
||||||
|
model.setBeamWidth(args['beam_width']);
|
||||||
|
}
|
||||||
|
|
||||||
var desired_sample_rate = model.sampleRate();
|
var desired_sample_rate = model.sampleRate();
|
||||||
|
|
||||||
if (args['scorer']) {
|
if (args['scorer']) {
|
||||||
|
|
|
@ -25,14 +25,13 @@ if (process.platform === 'win32') {
|
||||||
* An object providing an interface to a trained DeepSpeech model.
|
* An object providing an interface to a trained DeepSpeech model.
|
||||||
*
|
*
|
||||||
* @param {string} aModelPath The path to the frozen model graph.
|
* @param {string} aModelPath The path to the frozen model graph.
|
||||||
* @param {number} aBeamWidth The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time.
|
|
||||||
*
|
*
|
||||||
* @throws on error
|
* @throws on error
|
||||||
*/
|
*/
|
||||||
function Model() {
|
function Model(aModelPath) {
|
||||||
this._impl = null;
|
this._impl = null;
|
||||||
|
|
||||||
const rets = binding.CreateModel.apply(null, arguments);
|
const rets = binding.CreateModel(aModelPath);
|
||||||
const status = rets[0];
|
const status = rets[0];
|
||||||
const impl = rets[1];
|
const impl = rets[1];
|
||||||
if (status !== 0) {
|
if (status !== 0) {
|
||||||
|
@ -42,6 +41,38 @@ function Model() {
|
||||||
this._impl = impl;
|
this._impl = impl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
|
||||||
|
* was not called before, will return the default value loaded from the
|
||||||
|
* model file.
|
||||||
|
*
|
||||||
|
* @return {number} Beam width value used by the model.
|
||||||
|
*/
|
||||||
|
Model.prototype.beamWidth = function() {
|
||||||
|
return binding.GetModelBeamWidth(this._impl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set beam width value used by the model.
|
||||||
|
*
|
||||||
|
* @param {number} The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
|
||||||
|
*
|
||||||
|
* @return {number} Zero on success, non-zero on failure.
|
||||||
|
*/
|
||||||
|
Model.prototype.setBeamWidth = function(aBeamWidth) {
|
||||||
|
return binding.SetModelBeamWidth(this._impl, aBeamWidth);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the sample rate expected by the model.
|
||||||
|
*
|
||||||
|
* @return {number} Sample rate.
|
||||||
|
*/
|
||||||
|
Model.prototype.beamWidth = function() {
|
||||||
|
return binding.GetModelBeamWidth(this._impl);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the sample rate expected by the model.
|
* Return the sample rate expected by the model.
|
||||||
*
|
*
|
||||||
|
|
|
@ -28,15 +28,12 @@ class Model(object):
|
||||||
|
|
||||||
:param aModelPath: Path to model file to load
|
:param aModelPath: Path to model file to load
|
||||||
:type aModelPath: str
|
:type aModelPath: str
|
||||||
|
|
||||||
:param aBeamWidth: Decoder beam width
|
|
||||||
:type aBeamWidth: int
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, model_path):
|
||||||
# make sure the attribute is there if CreateModel fails
|
# make sure the attribute is there if CreateModel fails
|
||||||
self._impl = None
|
self._impl = None
|
||||||
|
|
||||||
status, impl = deepspeech.impl.CreateModel(*args, **kwargs)
|
status, impl = deepspeech.impl.CreateModel(model_path)
|
||||||
if status != 0:
|
if status != 0:
|
||||||
raise RuntimeError("CreateModel failed with error code {}".format(status))
|
raise RuntimeError("CreateModel failed with error code {}".format(status))
|
||||||
self._impl = impl
|
self._impl = impl
|
||||||
|
@ -46,6 +43,29 @@ class Model(object):
|
||||||
deepspeech.impl.FreeModel(self._impl)
|
deepspeech.impl.FreeModel(self._impl)
|
||||||
self._impl = None
|
self._impl = None
|
||||||
|
|
||||||
|
def beamWidth(self):
|
||||||
|
"""
|
||||||
|
Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
|
||||||
|
was not called before, will return the default value loaded from the
|
||||||
|
model file.
|
||||||
|
|
||||||
|
:return: Beam width value used by the model.
|
||||||
|
:type: int
|
||||||
|
"""
|
||||||
|
return deepspeech.impl.GetModelBeamWidth(self._impl)
|
||||||
|
|
||||||
|
def setBeamWidth(self, beam_width):
|
||||||
|
"""
|
||||||
|
Set beam width value used by the model.
|
||||||
|
|
||||||
|
:param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
|
||||||
|
:type beam_width: int
|
||||||
|
|
||||||
|
:return: Zero on success, non-zero on failure.
|
||||||
|
:type: int
|
||||||
|
"""
|
||||||
|
return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width)
|
||||||
|
|
||||||
def sampleRate(self):
|
def sampleRate(self):
|
||||||
"""
|
"""
|
||||||
Return the sample rate expected by the model.
|
Return the sample rate expected by the model.
|
||||||
|
|
|
@ -92,7 +92,7 @@ def main():
|
||||||
help='Path to the external scorer file')
|
help='Path to the external scorer file')
|
||||||
parser.add_argument('--audio', required=True,
|
parser.add_argument('--audio', required=True,
|
||||||
help='Path to the audio file to run (WAV format)')
|
help='Path to the audio file to run (WAV format)')
|
||||||
parser.add_argument('--beam_width', type=int, default=500,
|
parser.add_argument('--beam_width', type=int,
|
||||||
help='Beam width for the CTC decoder')
|
help='Beam width for the CTC decoder')
|
||||||
parser.add_argument('--lm_alpha', type=float,
|
parser.add_argument('--lm_alpha', type=float,
|
||||||
help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
|
help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
|
||||||
|
@ -108,10 +108,13 @@ def main():
|
||||||
|
|
||||||
print('Loading model from file {}'.format(args.model), file=sys.stderr)
|
print('Loading model from file {}'.format(args.model), file=sys.stderr)
|
||||||
model_load_start = timer()
|
model_load_start = timer()
|
||||||
ds = Model(args.model, args.beam_width)
|
ds = Model(args.model)
|
||||||
model_load_end = timer() - model_load_start
|
model_load_end = timer() - model_load_start
|
||||||
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
|
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
|
||||||
|
|
||||||
|
if args.beam_width:
|
||||||
|
ds.setModelBeamWidth(args.beam_width)
|
||||||
|
|
||||||
desired_sample_rate = ds.sampleRate()
|
desired_sample_rate = ds.sampleRate()
|
||||||
|
|
||||||
if args.scorer:
|
if args.scorer:
|
||||||
|
|
|
@ -9,12 +9,6 @@ import wave
|
||||||
from deepspeech import Model
|
from deepspeech import Model
|
||||||
|
|
||||||
|
|
||||||
# These constants control the beam search decoder
|
|
||||||
|
|
||||||
# Beam width used in the CTC decoder when building candidate transcriptions
|
|
||||||
BEAM_WIDTH = 500
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
|
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
|
||||||
parser.add_argument('--model', required=True,
|
parser.add_argument('--model', required=True,
|
||||||
|
@ -27,7 +21,7 @@ def main():
|
||||||
help='Second audio file to use in interleaved streams')
|
help='Second audio file to use in interleaved streams')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ds = Model(args.model, BEAM_WIDTH)
|
ds = Model(args.model)
|
||||||
|
|
||||||
if args.scorer:
|
if args.scorer:
|
||||||
ds.enableExternalScorer(args.scorer)
|
ds.enableExternalScorer(args.scorer)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче