зеркало из https://github.com/mozilla/DeepSpeech.git
Родитель
32c06acdf8
Коммит
f54b1f6ad8
|
@ -1,14 +1,21 @@
|
|||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <assert.h>
|
||||
#include <dirent.h>
|
||||
#include <errno.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <sox.h>
|
||||
#include <time.h>
|
||||
#include "deepspeech.h"
|
||||
#ifdef __APPLE__
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <sys/stat.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "deepspeech.h"
|
||||
|
||||
#define N_CEP 26
|
||||
#define N_CONTEXT 9
|
||||
|
@ -59,36 +66,21 @@ LocalDsSTT(Model& aCtx, const short* aBuffer, size_t aBufferSize,
|
|||
return res;
|
||||
}
|
||||
|
||||
int
|
||||
main(int argc, char **argv)
|
||||
struct ds_audio_buffer {
|
||||
char* buffer;
|
||||
size_t buffer_size;
|
||||
int sample_rate;
|
||||
};
|
||||
|
||||
struct ds_audio_buffer*
|
||||
GetAudioBuffer(const char* path)
|
||||
{
|
||||
if (argc < 4 || argc > 7) {
|
||||
printf("Usage: deepspeech MODEL_PATH ALPHABET_PATH [LM_PATH] [TRIE_PATH] AUDIO_PATH [-t]\n");
|
||||
printf(" MODEL_PATH\tPath to the model (protocol buffer binary file)\n");
|
||||
printf(" ALPHABET_PATH\tPath to the configuration file specifying"
|
||||
" the alphabet used by the network.\n");
|
||||
printf(" LM_PATH\tOptional: Path to the language model binary file.\n");
|
||||
printf(" TRIE_PATH\tOptional: Path to the language model trie file created with"
|
||||
" native_client/generate_trie.\n");
|
||||
printf(" AUDIO_PATH\tPath to the audio file to run"
|
||||
" (any file format supported by libsox)\n");
|
||||
printf(" -t\t\tRun in benchmark mode, output mfcc & inference time\n");
|
||||
return 1;
|
||||
struct ds_audio_buffer* res = (struct ds_audio_buffer*)malloc(sizeof(struct ds_audio_buffer));
|
||||
if (!res) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Initialise DeepSpeech
|
||||
Model ctx = Model(argv[1], N_CEP, N_CONTEXT, argv[2], BEAM_WIDTH);
|
||||
|
||||
if (argc > 5) {
|
||||
ctx.enableDecoderWithLM(argv[2], argv[3], argv[4], LM_WEIGHT, WORD_COUNT_WEIGHT, VALID_WORD_COUNT_WEIGHT);
|
||||
}
|
||||
|
||||
// Initialise SOX
|
||||
assert(sox_init() == SOX_SUCCESS);
|
||||
|
||||
// Handle case when LM_PATH and TRIE_PATH are not passed
|
||||
const char* wav_arg = (argc <= 5) ? argv[3] : argv[5];
|
||||
sox_format_t* input = sox_open_read(wav_arg, NULL, NULL, NULL);
|
||||
sox_format_t* input = sox_open_read(path, NULL, NULL, NULL);
|
||||
assert(input);
|
||||
|
||||
// Resample/reformat the audio so we can pass it through the MFCC functions
|
||||
|
@ -121,9 +113,8 @@ main(int argc, char **argv)
|
|||
sox_format_t* output = sox_open_write(output_name, &target_signal,
|
||||
&target_encoding, "raw", NULL, NULL);
|
||||
#else
|
||||
char* buffer;
|
||||
size_t buffer_size;
|
||||
sox_format_t* output = sox_open_memstream_write(&buffer, &buffer_size,
|
||||
sox_format_t* output = sox_open_memstream_write(&res->buffer,
|
||||
&res->buffer_size,
|
||||
&target_signal,
|
||||
&target_encoding,
|
||||
"raw", NULL);
|
||||
|
@ -131,7 +122,7 @@ main(int argc, char **argv)
|
|||
|
||||
assert(output);
|
||||
|
||||
int sampleRate = (int)output->signal.rate;
|
||||
res->sample_rate = (int)output->signal.rate;
|
||||
|
||||
if ((int)input->signal.rate < 16000) {
|
||||
fprintf(stderr, "Warning: original sample rate (%d) is lower than 16kHz. Up-sampling might produce erratic speech recognition.\n", (int)input->signal.rate);
|
||||
|
@ -179,20 +170,31 @@ main(int argc, char **argv)
|
|||
sox_close(input);
|
||||
|
||||
#ifdef __APPLE__
|
||||
size_t buffer_size = (size_t)(output->olength * 2);
|
||||
char* buffer = (char*)malloc(sizeof(char) * buffer_size);
|
||||
res->buffer_size = (size_t)(output->olength * 2);
|
||||
res->buffer = (char*)malloc(sizeof(char) * res->buffer_size);
|
||||
FILE* output_file = fopen(output_name, "rb");
|
||||
assert(fread(buffer, sizeof(char), buffer_size, output_file) == buffer_size);
|
||||
assert(fread(res->buffer, sizeof(char), res->buffer_size, output_file) == res->buffer_size);
|
||||
fclose(output_file);
|
||||
unlink(output_name);
|
||||
#endif
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void
|
||||
ProcessFile(Model& context, const char* path, bool show_times)
|
||||
{
|
||||
struct ds_audio_buffer* audio = GetAudioBuffer(path);
|
||||
|
||||
// Pass audio to DeepSpeech
|
||||
// We take half of buffer_size because buffer is a char* while
|
||||
// LocalDsSTT() expected a short*
|
||||
struct ds_result* result = LocalDsSTT(ctx, (const short*)buffer,
|
||||
buffer_size / 2, sampleRate);
|
||||
free(buffer);
|
||||
struct ds_result* result = LocalDsSTT(context,
|
||||
(const short*)audio->buffer,
|
||||
audio->buffer_size / 2,
|
||||
audio->sample_rate);
|
||||
free(audio->buffer);
|
||||
free(audio);
|
||||
|
||||
if (result) {
|
||||
if (result->string) {
|
||||
|
@ -200,7 +202,7 @@ main(int argc, char **argv)
|
|||
free(result->string);
|
||||
}
|
||||
|
||||
if (!strncmp(argv[argc-1], "-t", 3)) {
|
||||
if (show_times) {
|
||||
printf("cpu_time_overall=%.05f cpu_time_mfcc=%.05f "
|
||||
"cpu_time_infer=%.05f\n",
|
||||
result->cpu_time_overall,
|
||||
|
@ -210,6 +212,75 @@ main(int argc, char **argv)
|
|||
|
||||
free(result);
|
||||
}
|
||||
}
|
||||
|
||||
int
|
||||
main(int argc, char **argv)
|
||||
{
|
||||
if (argc < 4 || argc > 7) {
|
||||
printf("Usage: deepspeech MODEL_PATH ALPHABET_PATH [LM_PATH] [TRIE_PATH] AUDIO_PATH [-t]\n");
|
||||
printf(" MODEL_PATH\tPath to the model (protocol buffer binary file)\n");
|
||||
printf(" ALPHABET_PATH\tPath to the configuration file specifying"
|
||||
" the alphabet used by the network.\n");
|
||||
printf(" LM_PATH\tOptional: Path to the language model binary file.\n");
|
||||
printf(" TRIE_PATH\tOptional: Path to the language model trie file created with"
|
||||
" native_client/generate_trie.\n");
|
||||
printf(" AUDIO_PATH\tPath to the audio file (or directory of files) to run"
|
||||
" (any file format supported by libsox). \n");
|
||||
printf(" -t\t\tRun in benchmark mode, output mfcc & inference time\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Initialise DeepSpeech
|
||||
Model ctx = Model(argv[1], N_CEP, N_CONTEXT, argv[2], BEAM_WIDTH);
|
||||
|
||||
if (argc > 5) {
|
||||
ctx.enableDecoderWithLM(argv[2], argv[3], argv[4], LM_WEIGHT, WORD_COUNT_WEIGHT, VALID_WORD_COUNT_WEIGHT);
|
||||
}
|
||||
|
||||
// Initialise SOX
|
||||
assert(sox_init() == SOX_SUCCESS);
|
||||
|
||||
// Handle case when LM_PATH and TRIE_PATH are not passed
|
||||
const char* path = (argc <= 5) ? argv[3] : argv[5];
|
||||
bool show_times = !strncmp(argv[argc-1], "-t", 3);
|
||||
|
||||
struct stat wav_info;
|
||||
if (0 != stat(path, &wav_info)) {
|
||||
printf("Error on stat: %d\n", errno);
|
||||
}
|
||||
|
||||
switch (wav_info.st_mode & S_IFMT) {
|
||||
case S_IFLNK:
|
||||
case S_IFREG:
|
||||
ProcessFile(ctx, path, show_times);
|
||||
break;
|
||||
|
||||
case S_IFDIR:
|
||||
{
|
||||
printf("Running on directory %s\n", path);
|
||||
DIR* wav_dir = opendir(path);
|
||||
assert(wav_dir);
|
||||
|
||||
struct dirent* entry;
|
||||
while ((entry = readdir(wav_dir)) != NULL) {
|
||||
std::string fname = std::string(entry->d_name);
|
||||
if (fname.find(".wav") == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string fullpath = std::string(path) + std::string("/") + fname;
|
||||
printf("> %s\n", fullpath.c_str());
|
||||
ProcessFile(ctx, fullpath.c_str(), show_times);
|
||||
}
|
||||
closedir(wav_dir);
|
||||
}
|
||||
break;
|
||||
|
||||
default:
|
||||
printf("Unexpected type for %s: %d\n", path, (wav_info.st_mode & S_IFMT));
|
||||
break;
|
||||
}
|
||||
|
||||
// Deinitialise and quit
|
||||
sox_quit();
|
||||
|
|
|
@ -11,3 +11,5 @@ download_material "${TASKCLUSTER_TMP_DIR}/ds" "${aot_model}"
|
|||
export PATH=${TASKCLUSTER_TMP_DIR}/ds/:$PATH
|
||||
|
||||
run_all_inference_tests
|
||||
|
||||
run_multi_inference_tests
|
||||
|
|
|
@ -103,7 +103,7 @@ assert_working_inference()
|
|||
esac
|
||||
}
|
||||
|
||||
assert_shows_warning()
|
||||
assert_shows_something()
|
||||
{
|
||||
stderr=$1
|
||||
expected=$2
|
||||
|
@ -139,6 +139,14 @@ assert_correct_ldc93s1()
|
|||
assert_correct_inference "$1" "she had your dark suit in greasy wash water all year"
|
||||
}
|
||||
|
||||
assert_correct_multi_ldc93s1()
|
||||
{
|
||||
assert_shows_something "$1" "/LDC93S1.wav%she had your dark suit in greasy wash water all year%"
|
||||
assert_shows_something "$1" "/LDC93S1_pcms16le_2_44100.wav%she had your dark suit in greasy wash water all year%"
|
||||
## 8k will output garbage anyway ...
|
||||
# assert_shows_something "$1" "/LDC93S1_pcms16le_1_8000.wav%she hayorasryrtl lyreasy asr watal w water all year%"
|
||||
}
|
||||
|
||||
assert_correct_ldc93s1_prodmodel()
|
||||
{
|
||||
assert_correct_inference "$1" "she had the duck so ingrecywachworallyear"
|
||||
|
@ -182,7 +190,7 @@ assert_correct_ldc93s1_somodel()
|
|||
|
||||
assert_correct_warning_upsampling()
|
||||
{
|
||||
assert_shows_warning "$1" "is lower than 16kHz. Up-sampling might produce erratic speech recognition"
|
||||
assert_shows_something "$1" "is lower than 16kHz. Up-sampling might produce erratic speech recognition"
|
||||
}
|
||||
|
||||
run_all_inference_tests()
|
||||
|
@ -235,6 +243,15 @@ run_prod_inference_tests()
|
|||
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
||||
}
|
||||
|
||||
run_multi_inference_tests()
|
||||
{
|
||||
multi_phrase_pbmodel_nolm=$(deepspeech ${TASKCLUSTER_TMP_DIR}/${model_name} ${TASKCLUSTER_TMP_DIR}/alphabet.txt ${TASKCLUSTER_TMP_DIR}/ | tr '\n' '%')
|
||||
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_nolm}"
|
||||
|
||||
multi_phrase_pbmodel_withlm=$(deepspeech ${TASKCLUSTER_TMP_DIR}/${model_name} ${TASKCLUSTER_TMP_DIR}/alphabet.txt ${TASKCLUSTER_TMP_DIR}/lm.binary ${TASKCLUSTER_TMP_DIR}/trie ${TASKCLUSTER_TMP_DIR}/ | tr '\n' '%')
|
||||
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}"
|
||||
}
|
||||
|
||||
generic_download_tarxz()
|
||||
{
|
||||
target_dir=$1
|
||||
|
|
Загрузка…
Ссылка в новой задаче