Multiple audio native_client benchmark

Fixes #1232
This commit is contained in:
Alexandre Lissy 2018-02-13 14:00:10 +01:00
Родитель 32c06acdf8
Коммит f54b1f6ad8
3 изменённых файлов: 133 добавлений и 43 удалений

Просмотреть файл

@ -1,14 +1,21 @@
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <dirent.h>
#include <errno.h>
#include <math.h>
#include <string.h>
#include <sox.h>
#include <time.h>
#include "deepspeech.h"
#ifdef __APPLE__
#include <unistd.h>
#endif
#include <sys/types.h>
#include <sys/stat.h>
#include <string>
#include "deepspeech.h"
#define N_CEP 26
#define N_CONTEXT 9
@ -59,36 +66,21 @@ LocalDsSTT(Model& aCtx, const short* aBuffer, size_t aBufferSize,
return res;
}
int
main(int argc, char **argv)
struct ds_audio_buffer {
char* buffer;
size_t buffer_size;
int sample_rate;
};
struct ds_audio_buffer*
GetAudioBuffer(const char* path)
{
if (argc < 4 || argc > 7) {
printf("Usage: deepspeech MODEL_PATH ALPHABET_PATH [LM_PATH] [TRIE_PATH] AUDIO_PATH [-t]\n");
printf(" MODEL_PATH\tPath to the model (protocol buffer binary file)\n");
printf(" ALPHABET_PATH\tPath to the configuration file specifying"
" the alphabet used by the network.\n");
printf(" LM_PATH\tOptional: Path to the language model binary file.\n");
printf(" TRIE_PATH\tOptional: Path to the language model trie file created with"
" native_client/generate_trie.\n");
printf(" AUDIO_PATH\tPath to the audio file to run"
" (any file format supported by libsox)\n");
printf(" -t\t\tRun in benchmark mode, output mfcc & inference time\n");
return 1;
struct ds_audio_buffer* res = (struct ds_audio_buffer*)malloc(sizeof(struct ds_audio_buffer));
if (!res) {
return NULL;
}
// Initialise DeepSpeech
Model ctx = Model(argv[1], N_CEP, N_CONTEXT, argv[2], BEAM_WIDTH);
if (argc > 5) {
ctx.enableDecoderWithLM(argv[2], argv[3], argv[4], LM_WEIGHT, WORD_COUNT_WEIGHT, VALID_WORD_COUNT_WEIGHT);
}
// Initialise SOX
assert(sox_init() == SOX_SUCCESS);
// Handle case when LM_PATH and TRIE_PATH are not passed
const char* wav_arg = (argc <= 5) ? argv[3] : argv[5];
sox_format_t* input = sox_open_read(wav_arg, NULL, NULL, NULL);
sox_format_t* input = sox_open_read(path, NULL, NULL, NULL);
assert(input);
// Resample/reformat the audio so we can pass it through the MFCC functions
@ -121,9 +113,8 @@ main(int argc, char **argv)
sox_format_t* output = sox_open_write(output_name, &target_signal,
&target_encoding, "raw", NULL, NULL);
#else
char* buffer;
size_t buffer_size;
sox_format_t* output = sox_open_memstream_write(&buffer, &buffer_size,
sox_format_t* output = sox_open_memstream_write(&res->buffer,
&res->buffer_size,
&target_signal,
&target_encoding,
"raw", NULL);
@ -131,7 +122,7 @@ main(int argc, char **argv)
assert(output);
int sampleRate = (int)output->signal.rate;
res->sample_rate = (int)output->signal.rate;
if ((int)input->signal.rate < 16000) {
fprintf(stderr, "Warning: original sample rate (%d) is lower than 16kHz. Up-sampling might produce erratic speech recognition.\n", (int)input->signal.rate);
@ -179,20 +170,31 @@ main(int argc, char **argv)
sox_close(input);
#ifdef __APPLE__
size_t buffer_size = (size_t)(output->olength * 2);
char* buffer = (char*)malloc(sizeof(char) * buffer_size);
res->buffer_size = (size_t)(output->olength * 2);
res->buffer = (char*)malloc(sizeof(char) * res->buffer_size);
FILE* output_file = fopen(output_name, "rb");
assert(fread(buffer, sizeof(char), buffer_size, output_file) == buffer_size);
assert(fread(res->buffer, sizeof(char), res->buffer_size, output_file) == res->buffer_size);
fclose(output_file);
unlink(output_name);
#endif
return res;
}
void
ProcessFile(Model& context, const char* path, bool show_times)
{
struct ds_audio_buffer* audio = GetAudioBuffer(path);
// Pass audio to DeepSpeech
// We take half of buffer_size because buffer is a char* while
// LocalDsSTT() expected a short*
struct ds_result* result = LocalDsSTT(ctx, (const short*)buffer,
buffer_size / 2, sampleRate);
free(buffer);
struct ds_result* result = LocalDsSTT(context,
(const short*)audio->buffer,
audio->buffer_size / 2,
audio->sample_rate);
free(audio->buffer);
free(audio);
if (result) {
if (result->string) {
@ -200,7 +202,7 @@ main(int argc, char **argv)
free(result->string);
}
if (!strncmp(argv[argc-1], "-t", 3)) {
if (show_times) {
printf("cpu_time_overall=%.05f cpu_time_mfcc=%.05f "
"cpu_time_infer=%.05f\n",
result->cpu_time_overall,
@ -210,6 +212,75 @@ main(int argc, char **argv)
free(result);
}
}
int
main(int argc, char **argv)
{
if (argc < 4 || argc > 7) {
printf("Usage: deepspeech MODEL_PATH ALPHABET_PATH [LM_PATH] [TRIE_PATH] AUDIO_PATH [-t]\n");
printf(" MODEL_PATH\tPath to the model (protocol buffer binary file)\n");
printf(" ALPHABET_PATH\tPath to the configuration file specifying"
" the alphabet used by the network.\n");
printf(" LM_PATH\tOptional: Path to the language model binary file.\n");
printf(" TRIE_PATH\tOptional: Path to the language model trie file created with"
" native_client/generate_trie.\n");
printf(" AUDIO_PATH\tPath to the audio file (or directory of files) to run"
" (any file format supported by libsox). \n");
printf(" -t\t\tRun in benchmark mode, output mfcc & inference time\n");
return 1;
}
// Initialise DeepSpeech
Model ctx = Model(argv[1], N_CEP, N_CONTEXT, argv[2], BEAM_WIDTH);
if (argc > 5) {
ctx.enableDecoderWithLM(argv[2], argv[3], argv[4], LM_WEIGHT, WORD_COUNT_WEIGHT, VALID_WORD_COUNT_WEIGHT);
}
// Initialise SOX
assert(sox_init() == SOX_SUCCESS);
// Handle case when LM_PATH and TRIE_PATH are not passed
const char* path = (argc <= 5) ? argv[3] : argv[5];
bool show_times = !strncmp(argv[argc-1], "-t", 3);
struct stat wav_info;
if (0 != stat(path, &wav_info)) {
printf("Error on stat: %d\n", errno);
}
switch (wav_info.st_mode & S_IFMT) {
case S_IFLNK:
case S_IFREG:
ProcessFile(ctx, path, show_times);
break;
case S_IFDIR:
{
printf("Running on directory %s\n", path);
DIR* wav_dir = opendir(path);
assert(wav_dir);
struct dirent* entry;
while ((entry = readdir(wav_dir)) != NULL) {
std::string fname = std::string(entry->d_name);
if (fname.find(".wav") == std::string::npos) {
continue;
}
std::string fullpath = std::string(path) + std::string("/") + fname;
printf("> %s\n", fullpath.c_str());
ProcessFile(ctx, fullpath.c_str(), show_times);
}
closedir(wav_dir);
}
break;
default:
printf("Unexpected type for %s: %d\n", path, (wav_info.st_mode & S_IFMT));
break;
}
// Deinitialise and quit
sox_quit();

Просмотреть файл

@ -11,3 +11,5 @@ download_material "${TASKCLUSTER_TMP_DIR}/ds" "${aot_model}"
export PATH=${TASKCLUSTER_TMP_DIR}/ds/:$PATH
run_all_inference_tests
run_multi_inference_tests

Просмотреть файл

@ -103,7 +103,7 @@ assert_working_inference()
esac
}
assert_shows_warning()
assert_shows_something()
{
stderr=$1
expected=$2
@ -139,6 +139,14 @@ assert_correct_ldc93s1()
assert_correct_inference "$1" "she had your dark suit in greasy wash water all year"
}
assert_correct_multi_ldc93s1()
{
assert_shows_something "$1" "/LDC93S1.wav%she had your dark suit in greasy wash water all year%"
assert_shows_something "$1" "/LDC93S1_pcms16le_2_44100.wav%she had your dark suit in greasy wash water all year%"
## 8k will output garbage anyway ...
# assert_shows_something "$1" "/LDC93S1_pcms16le_1_8000.wav%she hayorasryrtl lyreasy asr watal w water all year%"
}
assert_correct_ldc93s1_prodmodel()
{
assert_correct_inference "$1" "she had the duck so ingrecywachworallyear"
@ -182,7 +190,7 @@ assert_correct_ldc93s1_somodel()
assert_correct_warning_upsampling()
{
assert_shows_warning "$1" "is lower than 16kHz. Up-sampling might produce erratic speech recognition"
assert_shows_something "$1" "is lower than 16kHz. Up-sampling might produce erratic speech recognition"
}
run_all_inference_tests()
@ -235,6 +243,15 @@ run_prod_inference_tests()
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
}
run_multi_inference_tests()
{
multi_phrase_pbmodel_nolm=$(deepspeech ${TASKCLUSTER_TMP_DIR}/${model_name} ${TASKCLUSTER_TMP_DIR}/alphabet.txt ${TASKCLUSTER_TMP_DIR}/ | tr '\n' '%')
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_nolm}"
multi_phrase_pbmodel_withlm=$(deepspeech ${TASKCLUSTER_TMP_DIR}/${model_name} ${TASKCLUSTER_TMP_DIR}/alphabet.txt ${TASKCLUSTER_TMP_DIR}/lm.binary ${TASKCLUSTER_TMP_DIR}/trie ${TASKCLUSTER_TMP_DIR}/ | tr '\n' '%')
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}"
}
generic_download_tarxz()
{
target_dir=$1