зеркало из https://github.com/mozilla/kaldi.git
trunk: add multi-threaded online-nnet2 decoding program, online2-wav-nnet2-latgen-threaded, which does decoding and nnet evaluation in different threads. Usage is otherwise similar to online2-wav-nnet2-latgen-faster.
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4844 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Родитель
41a8f9b3cb
Коммит
59736ea848
|
@ -148,3 +148,26 @@ if [ $stage -le 13 ]; then
|
|||
fi
|
||||
|
||||
exit 0;
|
||||
###### Comment out the "exit 0" above to run the multi-threaded decoding. #####
|
||||
|
||||
if [ $stage -le 14 ]; then
|
||||
# Demonstrate the multi-threaded decoding.
|
||||
# put back the pp
|
||||
test=dev_clean
|
||||
steps/online/nnet2/decode.sh --threaded true \
|
||||
--config conf/decode.config --cmd "$decode_cmd" --nj 30 \
|
||||
--per-utt true exp/tri6b/graph_pp_tgsmall data/$test \
|
||||
${dir}_online/decode_pp_${test}_tgsmall_utt_threaded || exit 1;
|
||||
fi
|
||||
|
||||
if [ $stage -le 15 ]; then
|
||||
# Demonstrate the multi-threaded decoding with endpointing.
|
||||
# put back the pp
|
||||
test=dev_clean
|
||||
steps/online/nnet2/decode.sh --threaded true --do-endpointing true \
|
||||
--config conf/decode.config --cmd "$decode_cmd" --nj 30 \
|
||||
--per-utt true exp/tri6b/graph_pp_tgsmall data/$test \
|
||||
${dir}_online/decode_pp_${test}_tgsmall_utt_threaded_ep || exit 1;
|
||||
fi
|
||||
|
||||
exit 0;
|
||||
|
|
|
@ -151,10 +151,11 @@ utils/mkgraph.sh data/lang_test_4g exp/tri3b exp/tri3b/graph_4g || exit 1;
|
|||
steps/decode_fmllr.sh --cmd "$decode_cmd" --nj 7 \
|
||||
exp/tri3b/graph_4g data/test1k exp/tri3b/decode_4g_test1k || exit 1;
|
||||
|
||||
# Train RNN for reranking
|
||||
local/sprak_train_rnnlms.sh data/local/dict data/dev/transcripts.uniq data/local/rnnlms/g_c380_d1k_h100_v130k
|
||||
# Consumes a lot of memory! Do not run in parallel
|
||||
local/sprak_run_rnnlms_tri3b.sh data/lang_test_3g data/local/rnnlms/g_c380_d1k_h100_v130k data/test1k exp/tri3b/decode_3g_test1k
|
||||
# This is commented out for now as it's not important for the main recipe.
|
||||
## Train RNN for reranking
|
||||
#local/sprak_train_rnnlms.sh data/local/dict data/dev/transcripts.uniq data/local/rnnlms/g_c380_d1k_h100_v130k
|
||||
## Consumes a lot of memory! Do not run in parallel
|
||||
#local/sprak_run_rnnlms_tri3b.sh data/lang_test_3g data/local/rnnlms/g_c380_d1k_h100_v130k data/test1k exp/tri3b/decode_3g_test1k
|
||||
|
||||
|
||||
# From 3b system
|
||||
|
|
|
@ -8,12 +8,14 @@ stage=0
|
|||
nj=4
|
||||
cmd=run.pl
|
||||
max_active=7000
|
||||
threaded=false
|
||||
modify_ivector_config=false # only relevant to threaded decoder.
|
||||
beam=15.0
|
||||
lattice_beam=6.0
|
||||
acwt=0.1 # note: only really affects adaptation and pruning (scoring is on
|
||||
# lattices).
|
||||
per_utt=false
|
||||
online=true
|
||||
online=true # only relevant to non-threaded decoder.
|
||||
do_endpointing=false
|
||||
do_speex_compressing=false
|
||||
scoring_opts=
|
||||
|
@ -92,9 +94,23 @@ if $do_endpointing; then
|
|||
wav_rspecifier="$wav_rspecifier extend-wav-with-silence ark:- ark:- |"
|
||||
fi
|
||||
|
||||
|
||||
|
||||
if $threaded; then
|
||||
decoder=online2-wav-nnet2-latgen-threaded
|
||||
# note: the decoder actually uses 4 threads, but the average usage will normally
|
||||
# be more like 2.
|
||||
parallel_opts="--num-threads 2"
|
||||
opts="--modify-ivector-config=$modify_ivector_config --verbose=1"
|
||||
else
|
||||
decoder=online2-wav-nnet2-latgen-faster
|
||||
parallel_opts=
|
||||
opts="--online=$online"
|
||||
fi
|
||||
|
||||
if [ $stage -le 0 ]; then
|
||||
$cmd JOB=1:$nj $dir/log/decode.JOB.log \
|
||||
online2-wav-nnet2-latgen-faster --online=$online --do-endpointing=$do_endpointing \
|
||||
$cmd $parallel_opts JOB=1:$nj $dir/log/decode.JOB.log \
|
||||
$decoder $opts --do-endpointing=$do_endpointing \
|
||||
--config=$srcdir/conf/online_nnet2_decoding.conf \
|
||||
--max-active=$max_active --beam=$beam --lattice-beam=$lattice_beam \
|
||||
--acoustic-scale=$acwt --word-symbol-table=$graphdir/words.txt \
|
||||
|
|
|
@ -3,7 +3,7 @@ all:
|
|||
|
||||
include ../kaldi.mk
|
||||
|
||||
TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test
|
||||
TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test timer-test
|
||||
|
||||
OBJFILES = kaldi-math.o kaldi-error.o io-funcs.o kaldi-utils.o
|
||||
|
||||
|
|
|
@ -128,14 +128,19 @@ std::string KaldiGetStackTrace() {
|
|||
|
||||
void KaldiAssertFailure_(const char *func, const char *file,
|
||||
int32 line, const char *cond_str) {
|
||||
std::cerr << "KALDI_ASSERT: at " << GetProgramName() << func << ':'
|
||||
std::ostringstream ss;
|
||||
ss << "KALDI_ASSERT: at " << GetProgramName() << func << ':'
|
||||
<< GetShortFileName(file)
|
||||
<< ':' << line << ", failed: " << cond_str << '\n';
|
||||
#ifdef HAVE_EXECINFO_H
|
||||
std::cerr << "Stack trace is:\n" << KaldiGetStackTrace();
|
||||
ss << "Stack trace is:\n" << KaldiGetStackTrace();
|
||||
#endif
|
||||
std::cerr << ss.str();
|
||||
std::cerr.flush();
|
||||
abort(); // Will later throw instead if needed.
|
||||
// We used to call abort() here, but switch to throwing an exception
|
||||
// (like KALDI_ERR) because it's easier to deal with in multi-threaded
|
||||
// code.
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,15 @@
|
|||
#include <string>
|
||||
#include "base/kaldi-common.h"
|
||||
|
||||
|
||||
#ifdef _WIN32_WINNT_WIN8
|
||||
#include <Synchapi.h>
|
||||
#elif defined (_WIN32) || defined(_MSC_VER) || defined(MINGW)
|
||||
#include <Windows.h>
|
||||
#else
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
std::string CharToString(const char &c) {
|
||||
|
@ -30,4 +39,12 @@ std::string CharToString(const char &c) {
|
|||
return (std::string) buf;
|
||||
}
|
||||
|
||||
void Sleep(float seconds) {
|
||||
#if defined(_MSC_VER) || defined(MINGW)
|
||||
::Sleep(static_cast<int>(seconds * 1000.0));
|
||||
#else
|
||||
usleep(static_cast<int>(seconds * 1000000.0));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // end namespace kaldi
|
||||
|
|
|
@ -78,6 +78,10 @@ inline int MachineIsLittleEndian() {
|
|||
return (*reinterpret_cast<char*>(&check) != 0);
|
||||
}
|
||||
|
||||
// This function kaldi::Sleep() provides a portable way to sleep for a possibly fractional
|
||||
// number of seconds. On Windows it's only accurate to microseconds.
|
||||
void Sleep(float seconds);
|
||||
|
||||
}
|
||||
|
||||
#define KALDI_SWAP8(a) { \
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// base/timer-test.cc
|
||||
|
||||
// Copyright 2009-2011 Microsoft Corporation
|
||||
// 2014 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
|
@ -19,28 +20,27 @@
|
|||
|
||||
#include "base/timer.h"
|
||||
#include "base/kaldi-common.h"
|
||||
|
||||
#include "base/kaldi-utils.h"
|
||||
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
void TimerTest() {
|
||||
|
||||
float time_secs = 0.025 * (rand() % 10);
|
||||
std::cout << "target is " << time_secs << "\n";
|
||||
Timer timer;
|
||||
#if defined(_MSC_VER) || defined(MINGW)
|
||||
Sleep(1000);
|
||||
#else
|
||||
sleep(1);
|
||||
#endif
|
||||
Sleep(time_secs);
|
||||
BaseFloat f = timer.Elapsed();
|
||||
std::cout << "time is " << f;
|
||||
KALDI_ASSERT(fabs(1.0 - f) < 0.1);
|
||||
std::cout << "time is " << f << std::endl;
|
||||
if (fabs(time_secs - f) > 0.05)
|
||||
KALDI_ERR << "Timer fail: waited " << f << " seconds instead of "
|
||||
<< time_secs << " secs.";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
for (int i = 0; i < 4; i++)
|
||||
kaldi::TimerTest();
|
||||
}
|
||||
|
||||
|
|
|
@ -19,9 +19,11 @@
|
|||
#ifndef KALDI_BASE_TIMER_H_
|
||||
#define KALDI_BASE_TIMER_H_
|
||||
|
||||
#if defined(_MSC_VER) || defined(MINGW)
|
||||
|
||||
#include "base/kaldi-utils.h"
|
||||
// Note: Sleep(float secs) is included in base/kaldi-utils.h.
|
||||
|
||||
|
||||
#if defined(_MSC_VER) || defined(MINGW)
|
||||
|
||||
namespace kaldi
|
||||
{
|
||||
|
|
|
@ -83,6 +83,88 @@ class DecodableMatrixScaledMapped: public DecodableInterface {
|
|||
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixScaledMapped);
|
||||
};
|
||||
|
||||
/**
|
||||
This decodable class returns log-likes stored in a matrix; it supports
|
||||
repeatedly writing to the matrix and setting a time-offset representing the
|
||||
frame-index of the first row of the matrix. It's intended for use in
|
||||
multi-threaded decoding; mutex and semaphores are not included. External
|
||||
code will call SetLoglikes() each time more log-likelihods are available.
|
||||
If you try to access a log-likelihood that's no longer available because
|
||||
the frame index is less than the current offset, it is of course an error.
|
||||
*/
|
||||
class DecodableMatrixMappedOffset: public DecodableInterface {
|
||||
public:
|
||||
DecodableMatrixMappedOffset(const TransitionModel &tm):
|
||||
trans_model_(tm), frame_offset_(0), input_is_finished_(false) { }
|
||||
|
||||
|
||||
|
||||
virtual int32 NumFramesReady() { return frame_offset_ + loglikes_.NumRows(); }
|
||||
|
||||
// this is not part of the generic Decodable interface.
|
||||
int32 FirstAvailableFrame() { return frame_offset_; }
|
||||
|
||||
// This function is destructive of the input "loglikes" because it may
|
||||
// under some circumstances do a shallow copy using Swap(). This function
|
||||
// appends loglikes to any existing likelihoods you've previously supplied.
|
||||
// frames_to_discard, if nonzero, will discard that number of previously
|
||||
// available frames, from the left, advancing FirstAvailableFrame() by
|
||||
// a number equal to frames_to_discard. You should only set frames_to_discard
|
||||
// to nonzero if you know your decoder won't want to access the loglikes
|
||||
// for older frames.
|
||||
void AcceptLoglikes(Matrix<BaseFloat> *loglikes,
|
||||
int32 frames_to_discard) {
|
||||
if (loglikes->NumRows() == 0) return;
|
||||
KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs());
|
||||
KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() &&
|
||||
frames_to_discard >= 0);
|
||||
if (frames_to_discard == loglikes_.NumRows()) {
|
||||
loglikes_.Swap(loglikes);
|
||||
loglikes->Resize(0, 0);
|
||||
} else {
|
||||
int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard,
|
||||
new_num_rows = old_rows_kept + loglikes->NumRows();
|
||||
Matrix<BaseFloat> new_loglikes(new_num_rows, loglikes->NumCols());
|
||||
new_loglikes.RowRange(0, old_rows_kept).CopyFromMat(
|
||||
loglikes_.RowRange(frames_to_discard, old_rows_kept));
|
||||
new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat(
|
||||
*loglikes);
|
||||
loglikes_.Swap(&new_loglikes);
|
||||
}
|
||||
frame_offset_ += frames_to_discard;
|
||||
}
|
||||
|
||||
void InputIsFinished() { input_is_finished_ = true; }
|
||||
|
||||
virtual int32 NumFramesReady() const {
|
||||
return loglikes_.NumRows() + frame_offset_;
|
||||
}
|
||||
|
||||
virtual bool IsLastFrame(int32 frame) const {
|
||||
KALDI_ASSERT(frame < NumFramesReady());
|
||||
return (frame == NumFramesReady() - 1 && input_is_finished_);
|
||||
}
|
||||
|
||||
virtual BaseFloat LogLikelihood(int32 frame, int32 tid) {
|
||||
int32 index = frame - frame_offset_;
|
||||
KALDI_ASSERT(index >= 0 && index < loglikes_.NumRows());
|
||||
return loglikes_(index, trans_model_.TransitionIdToPdf(tid));
|
||||
}
|
||||
|
||||
|
||||
|
||||
virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); }
|
||||
|
||||
// nothing special to do in destructor.
|
||||
virtual ~DecodableMatrixMappedOffset() { }
|
||||
private:
|
||||
const TransitionModel &trans_model_; // for tid to pdf mapping
|
||||
Matrix<BaseFloat> loglikes_;
|
||||
int32 frame_offset_;
|
||||
bool input_is_finished_;
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixMappedOffset);
|
||||
};
|
||||
|
||||
|
||||
class DecodableMatrixScaled: public DecodableInterface {
|
||||
public:
|
||||
|
|
|
@ -19,11 +19,9 @@
|
|||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Note on svn: this file is "upstream" from lattice-faster-online-decoder.h,
|
||||
// and changes in this file should be merged into
|
||||
// lattice-faster-online-decoder.h, after committing the changes to this file,
|
||||
// using the command
|
||||
// svn merge ^/sandbox/online/src/decoder/lattice-faster-decoder.h lattice-faster-online-decoder.h
|
||||
// Note: this file is "upstream" from lattice-faster-online-decoder.h,
|
||||
// and changes in this file should be made to lattice-faster-online-decoder.h,
|
||||
// if applicable.
|
||||
|
||||
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
|
||||
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
|
||||
|
@ -165,8 +163,8 @@ class LatticeFasterDecoder {
|
|||
bool use_final_probs = true) const;
|
||||
|
||||
/// InitDecoding initializes the decoding, and should only be used if you
|
||||
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need
|
||||
/// to call this. You can call InitDecoding if you have already decoded an
|
||||
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
|
||||
/// call this. You can also call InitDecoding if you have already decoded an
|
||||
/// utterance and want to start with a new utterance.
|
||||
void InitDecoding();
|
||||
|
||||
|
@ -408,6 +406,7 @@ class LatticeFasterDecoder {
|
|||
|
||||
void ClearActiveTokens();
|
||||
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoder);
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -148,8 +148,8 @@ class LatticeFasterOnlineDecoder {
|
|||
|
||||
|
||||
/// InitDecoding initializes the decoding, and should only be used if you
|
||||
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need
|
||||
/// to call this. You can call InitDecoding if you have already decoded an
|
||||
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
|
||||
/// call this. You can also call InitDecoding if you have already decoded an
|
||||
/// utterance and want to start with a new utterance.
|
||||
void InitDecoding();
|
||||
|
||||
|
@ -398,6 +398,8 @@ class LatticeFasterOnlineDecoder {
|
|||
|
||||
void ClearActiveTokens();
|
||||
|
||||
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoder);
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -604,7 +604,7 @@ void OnlineIvectorEstimationStats::GetIvector(
|
|||
ivector->SetZero();
|
||||
(*ivector)(0) = prior_offset_;
|
||||
}
|
||||
KALDI_VLOG(3) << "Objective function improvement from estimating the "
|
||||
KALDI_VLOG(4) << "Objective function improvement from estimating the "
|
||||
<< "iVector (vs. default value) is "
|
||||
<< ObjfChange(*ivector);
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ OBJFILES = nnet-component.o nnet-nnet.o train-nnet.o train-nnet-ensemble.o nnet-
|
|||
get-feature-transform.o widen-nnet.o nnet-precondition-online.o \
|
||||
nnet-example-functions.o nnet-compute-discriminative.o \
|
||||
nnet-compute-discriminative-parallel.o online-nnet2-decodable.o \
|
||||
train-nnet-perturbed.o
|
||||
train-nnet-perturbed.o nnet-compute-online.o
|
||||
|
||||
LIBNAME = kaldi-nnet2
|
||||
|
||||
|
|
|
@ -20,16 +20,6 @@
|
|||
#include "nnet2/nnet-component.h"
|
||||
#include "util/common-utils.h"
|
||||
|
||||
#ifdef _WIN32_WINNT_WIN8
|
||||
#include <Synchapi.h>
|
||||
#define sleep Sleep
|
||||
#elif _WIN32
|
||||
#include <Windows.h>
|
||||
#define sleep Sleep
|
||||
#else
|
||||
#include <unistd.h> // for sleep().
|
||||
#endif
|
||||
|
||||
namespace kaldi {
|
||||
namespace nnet2 {
|
||||
|
||||
|
@ -332,7 +322,7 @@ void UnitTestAffineComponent() {
|
|||
mat.SetRandn();
|
||||
mat.Scale(param_stddev);
|
||||
WriteKaldiObject(mat, "tmpf", true);
|
||||
sleep(1);
|
||||
Sleep(0.5);
|
||||
component.Init(learning_rate, "tmpf");
|
||||
unlink("tmpf");
|
||||
}
|
||||
|
@ -433,7 +423,7 @@ void UnitTestAffineComponentPreconditioned() {
|
|||
mat.SetRandn();
|
||||
mat.Scale(param_stddev);
|
||||
WriteKaldiObject(mat, "tmpf", true);
|
||||
sleep(1);
|
||||
Sleep(0.5);
|
||||
component.Init(learning_rate, alpha, max_change, "tmpf");
|
||||
unlink("tmpf");
|
||||
}
|
||||
|
@ -467,7 +457,7 @@ void UnitTestAffineComponentPreconditionedOnline() {
|
|||
mat.SetRandn();
|
||||
mat.Scale(param_stddev);
|
||||
WriteKaldiObject(mat, "tmpf", true);
|
||||
sleep(1);
|
||||
Sleep(0.5);
|
||||
component.Init(learning_rate, rank_in, rank_out,
|
||||
update_period, num_samples_history, alpha,
|
||||
max_change_per_sample, "tmpf");
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
// nnet2/nnet-compute-online.cc
|
||||
|
||||
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
|
||||
// Guoguo Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "nnet2/nnet-compute-online.h"
|
||||
|
||||
namespace kaldi {
|
||||
namespace nnet2 {
|
||||
|
||||
NnetOnlineComputer::NnetOnlineComputer(const Nnet &nnet, bool pad_input)
|
||||
: nnet_(nnet), pad_input_(pad_input),
|
||||
is_first_chunk_(true), finished_(false) {
|
||||
data_.resize(nnet_.NumComponents() + 1);
|
||||
unused_input_.Resize(0, 0);
|
||||
// WARNING: we always pad in this dirty implementation.
|
||||
pad_input_ = true;
|
||||
}
|
||||
|
||||
void NnetOnlineComputer::Compute(const CuMatrixBase<BaseFloat> &input,
|
||||
CuMatrix<BaseFloat> *output) {
|
||||
KALDI_ASSERT(output != NULL);
|
||||
KALDI_ASSERT(!finished_);
|
||||
int32 dim = input.NumCols();
|
||||
|
||||
// If input is empty, we also set output to zero size.
|
||||
if (input.NumRows() == 0) {
|
||||
output->Resize(0, 0);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if feature dimension matches that required by the neural network.
|
||||
if (dim != nnet_.InputDim()) {
|
||||
KALDI_ERR << "Feature dimension is " << dim << ", but network expects "
|
||||
<< nnet_.InputDim();
|
||||
}
|
||||
|
||||
// Pad at the start of the file if necessary.
|
||||
if (pad_input_ && is_first_chunk_) {
|
||||
KALDI_ASSERT(unused_input_.NumRows() == 0);
|
||||
unused_input_.Resize(nnet_.LeftContext(), dim);
|
||||
for (int32 i = 0; i < nnet_.LeftContext(); i++)
|
||||
unused_input_.Row(i).CopyFromVec(input.Row(0));
|
||||
is_first_chunk_ = false;
|
||||
}
|
||||
|
||||
int32 num_rows = unused_input_.NumRows() + input.NumRows();
|
||||
|
||||
// Splice unused_input_ and input.
|
||||
CuMatrix<BaseFloat> &input_data(data_[0]);
|
||||
input_data.Resize(num_rows, dim);
|
||||
input_data.Range(0, unused_input_.NumRows(),
|
||||
0, dim).CopyFromMat(unused_input_);
|
||||
input_data.Range(unused_input_.NumRows(), input.NumRows(),
|
||||
0, dim).CopyFromMat(input);
|
||||
|
||||
if (num_rows > nnet_.LeftContext() + nnet_.RightContext()) {
|
||||
nnet_.ComputeChunkInfo(num_rows, 1, &chunk_info_);
|
||||
Propagate();
|
||||
*output = data_.back();
|
||||
} else {
|
||||
output->Resize(0, 0);
|
||||
}
|
||||
|
||||
// Now store the part of input that will be needed in the next call of
|
||||
// Compute().
|
||||
int32 unused_num_rows = nnet_.LeftContext() + nnet_.RightContext();
|
||||
if (unused_num_rows > num_rows) { unused_num_rows = num_rows; }
|
||||
unused_input_.Resize(unused_num_rows, dim);
|
||||
unused_input_.CopyFromMat(input_data.Range(num_rows - unused_num_rows,
|
||||
unused_num_rows, 0, dim));
|
||||
}
|
||||
|
||||
void NnetOnlineComputer::Flush(CuMatrix<BaseFloat> *output) {
|
||||
KALDI_ASSERT(!finished_);
|
||||
|
||||
int32 right_context = (pad_input_ ? nnet_.RightContext() : 0);
|
||||
int32 num_rows = unused_input_.NumRows() + right_context;
|
||||
|
||||
// If no frame needs to be computed, set output to empty and return.
|
||||
if (num_rows == 0 || unused_input_.NumRows() == 0) {
|
||||
output->Resize(0, 0);
|
||||
finished_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
int32 dim = unused_input_.NumCols();
|
||||
CuMatrix<BaseFloat> &input_data(data_[0]);
|
||||
input_data.Resize(num_rows, dim);
|
||||
input_data.Range(0, unused_input_.NumRows(),
|
||||
0, dim).CopyFromMat(unused_input_);
|
||||
if (right_context > 0) {
|
||||
int32 last_row = unused_input_.NumRows() - 1;
|
||||
for (int32 i = 0; i < right_context; i++)
|
||||
input_data.Row(num_rows - i - 1).CopyFromVec(unused_input_.Row(last_row));
|
||||
}
|
||||
|
||||
nnet_.ComputeChunkInfo(num_rows, 1, &chunk_info_);
|
||||
Propagate();
|
||||
*output = data_.back();
|
||||
|
||||
finished_ = true;
|
||||
}
|
||||
|
||||
void NnetOnlineComputer::Propagate() {
|
||||
for (int32 c = 0; c < nnet_.NumComponents(); ++c) {
|
||||
const Component &component = nnet_.GetComponent(c);
|
||||
CuMatrix<BaseFloat> &input_data = data_[c], &output_data = data_[c + 1];
|
||||
component.Propagate(chunk_info_[c], chunk_info_[c + 1],
|
||||
input_data, &output_data);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace nnet2
|
||||
} // namespace kaldi
|
|
@ -0,0 +1,99 @@
|
|||
// nnet2/nnet-compute-online.h
|
||||
|
||||
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
|
||||
// Guoguo Chen
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_NNET2_NNET_COMPUTE_ONLINE_H_
|
||||
#define KALDI_NNET2_NNET_COMPUTE_ONLINE_H_
|
||||
|
||||
#include "nnet2/nnet-nnet.h"
|
||||
|
||||
namespace kaldi {
|
||||
namespace nnet2 {
|
||||
|
||||
/* This header provides functionality for doing forward computation in a situation
|
||||
where you want to start from the beginning of a file and progressively compute
|
||||
more, while re-using the hidden parts that (due to context) may be shared.
|
||||
(note: this sharing is more of an issue in multi-splice networks where there is
|
||||
splicing over time in the middle layers of the network).
|
||||
Note: this doesn't do the final taking-the-log and correcting for the prior.
|
||||
The current implementation is just an inefficient placeholder implementation;
|
||||
later we'll modify it to properly use previously computed activations.
|
||||
*/
|
||||
|
||||
class NnetOnlineComputer {
|
||||
|
||||
public:
|
||||
// All the inputs and outputs are of type CuMatrix, in case we're doing the
|
||||
// computation on the GPU (of course, if there is no GPU, it backs off to
|
||||
// using the CPU).
|
||||
// You should initialize an object of this type for each utterance you want
|
||||
// to decode.
|
||||
|
||||
// Note: pad_input will normally be true; it means that at the start and end
|
||||
// of the file, we pad with repeats of the first/last frame, so that the total
|
||||
// number of frames it outputs is the same as the number of input frames.
|
||||
NnetOnlineComputer(const Nnet &nnet,
|
||||
bool pad_input);
|
||||
|
||||
// This function works as follows: given a chunk of input (interpreted
|
||||
// as following in time any previously supplied data), do the computation
|
||||
// and produce all the frames of output we can. In the middle of the
|
||||
// file, the dimensions of input and output will be the same, but at
|
||||
// the beginning of the file, output will have fewer frames than input
|
||||
// due to required context.
|
||||
// It is the responsibility of the user to keep track of frame indices, if
|
||||
// required. This class won't output any frame twice.
|
||||
void Compute(const CuMatrixBase<BaseFloat> &input,
|
||||
CuMatrix<BaseFloat> *output);
|
||||
|
||||
// This flushes out the last frames of output; you call this when all
|
||||
// input has finished. It's invalid to call Compute or Flush after
|
||||
// calling Flush. It's valid to call Flush if no frames have been
|
||||
// input or if no frames have been output; this produces empty output.
|
||||
void Flush(CuMatrix<BaseFloat> *output);
|
||||
|
||||
private:
|
||||
void Propagate();
|
||||
|
||||
const Nnet &nnet_;
|
||||
|
||||
// data_ contains the intermediate stages and the output of the most recent
|
||||
// computation.
|
||||
std::vector<CuMatrix<BaseFloat> > data_;
|
||||
|
||||
std::vector<ChunkInfo> chunk_info_;
|
||||
|
||||
CuMatrix<BaseFloat> unused_input_;
|
||||
|
||||
bool pad_input_;
|
||||
|
||||
bool is_first_chunk_;
|
||||
|
||||
bool finished_;
|
||||
// we might need more variables here to keep track of how many frames we
|
||||
// already output from data_.
|
||||
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(NnetOnlineComputer);
|
||||
};
|
||||
|
||||
|
||||
} // namespace nnet2
|
||||
} // namespace kaldi
|
||||
|
||||
#endif // KALDI_NNET2_NNET_COMPUTE_ONLINE_H_
|
|
@ -30,6 +30,10 @@
|
|||
namespace kaldi {
|
||||
namespace nnet2 {
|
||||
|
||||
// Note: see also nnet-compute-online.h, which provides a different
|
||||
// (lower-level) interface and more efficient for progressive evaluation of an
|
||||
// nnet throughout an utterance, with re-use of already-computed activations.
|
||||
|
||||
struct DecodableNnet2OnlineOptions {
|
||||
BaseFloat acoustic_scale;
|
||||
bool pad_input;
|
||||
|
@ -55,6 +59,12 @@ struct DecodableNnet2OnlineOptions {
|
|||
};
|
||||
|
||||
|
||||
/**
|
||||
This Decodable object for class nnet2::AmNnet takes feature input from class
|
||||
OnlineFeatureInterface, unlike, say, class DecodableAmNnet which takes
|
||||
feature input from a matrix.
|
||||
*/
|
||||
|
||||
class DecodableNnet2Online: public DecodableInterface {
|
||||
public:
|
||||
DecodableNnet2Online(const AmNnet &nnet,
|
||||
|
|
|
@ -7,7 +7,8 @@ TESTFILES =
|
|||
|
||||
OBJFILES = online-gmm-decodable.o online-feature-pipeline.o online-ivector-feature.o \
|
||||
online-nnet2-feature-pipeline.o online-gmm-decoding.o online-timing.o \
|
||||
online-endpoint.o onlinebin-util.o online-speex-wrapper.o online-nnet2-decoding.o
|
||||
online-endpoint.o onlinebin-util.o online-speex-wrapper.o \
|
||||
online-nnet2-decoding.o online-nnet2-decoding-threaded.o
|
||||
|
||||
LIBNAME = kaldi-online2
|
||||
|
||||
|
|
|
@ -0,0 +1,652 @@
|
|||
// online2/online-nnet2-decoding-threaded.cc
|
||||
|
||||
// Copyright 2013-2014 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "online2/online-nnet2-decoding-threaded.h"
|
||||
#include "nnet2/nnet-compute-online.h"
|
||||
#include "lat/lattice-functions.h"
|
||||
#include "lat/determinize-lattice-pruned.h"
|
||||
#include "thread/kaldi-thread.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
ThreadSynchronizer::ThreadSynchronizer():
|
||||
abort_(false),
|
||||
producer_waiting_(false),
|
||||
consumer_waiting_(false),
|
||||
num_errors_(0) {
|
||||
producer_semaphore_.Signal();
|
||||
consumer_semaphore_.Signal();
|
||||
}
|
||||
|
||||
bool ThreadSynchronizer::Lock(ThreadType t) {
|
||||
if (abort_)
|
||||
return false;
|
||||
if (t == ThreadSynchronizer::kProducer) {
|
||||
producer_semaphore_.Wait();
|
||||
} else {
|
||||
consumer_semaphore_.Wait();
|
||||
}
|
||||
if (abort_)
|
||||
return false;
|
||||
mutex_.Lock();
|
||||
held_by_ = t;
|
||||
if (abort_) {
|
||||
mutex_.Unlock();
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool ThreadSynchronizer::UnlockSuccess(ThreadType t) {
|
||||
if (t == ThreadSynchronizer::kProducer) {
|
||||
producer_semaphore_.Signal(); // next Lock won't wait.
|
||||
if (consumer_waiting_) {
|
||||
consumer_semaphore_.Signal();
|
||||
consumer_waiting_ = false;
|
||||
}
|
||||
} else {
|
||||
consumer_semaphore_.Signal(); // next Lock won't wait.
|
||||
if (producer_waiting_) {
|
||||
producer_semaphore_.Signal();
|
||||
producer_waiting_ = false;
|
||||
}
|
||||
|
||||
}
|
||||
mutex_.Unlock();
|
||||
return !abort_;
|
||||
}
|
||||
|
||||
bool ThreadSynchronizer::UnlockFailure(ThreadType t) {
|
||||
|
||||
KALDI_ASSERT(held_by_ == t && "Code error: unlocking a mutex you don't hold.");
|
||||
|
||||
if (t == ThreadSynchronizer::kProducer) {
|
||||
KALDI_ASSERT(!producer_waiting_ && "code error.");
|
||||
producer_waiting_ = true;
|
||||
} else {
|
||||
KALDI_ASSERT(!consumer_waiting_ && "code error.");
|
||||
consumer_waiting_ = true;
|
||||
}
|
||||
mutex_.Unlock();
|
||||
return !abort_;
|
||||
}
|
||||
|
||||
void ThreadSynchronizer::SetAbort() {
|
||||
abort_ = true;
|
||||
// we signal the semaphores just in case someone was waiting on either of
|
||||
// them.
|
||||
producer_semaphore_.Signal();
|
||||
consumer_semaphore_.Signal();
|
||||
}
|
||||
|
||||
ThreadSynchronizer::~ThreadSynchronizer() {
|
||||
}
|
||||
|
||||
// static
|
||||
void OnlineNnet2DecodingThreadedConfig::Check() {
|
||||
KALDI_ASSERT(max_buffered_features > 1);
|
||||
KALDI_ASSERT(feature_batch_size > 0);
|
||||
KALDI_ASSERT(max_loglikes_copy >= 0);
|
||||
KALDI_ASSERT(nnet_batch_size > 0);
|
||||
KALDI_ASSERT(decode_batch_size >= 1);
|
||||
}
|
||||
|
||||
|
||||
SingleUtteranceNnet2DecoderThreaded::SingleUtteranceNnet2DecoderThreaded(
|
||||
const OnlineNnet2DecodingThreadedConfig &config,
|
||||
const TransitionModel &tmodel,
|
||||
const nnet2::AmNnet &am_nnet,
|
||||
const fst::Fst<fst::StdArc> &fst,
|
||||
const OnlineNnet2FeaturePipelineInfo &feature_info,
|
||||
const OnlineIvectorExtractorAdaptationState &adaptation_state):
|
||||
config_(config), am_nnet_(am_nnet), tmodel_(tmodel), sampling_rate_(0.0),
|
||||
num_samples_received_(0), input_finished_(false),
|
||||
feature_pipeline_(feature_info), feature_buffer_start_frame_(0),
|
||||
feature_buffer_finished_(false), decodable_(tmodel),
|
||||
num_frames_decoded_(0), decoder_(fst, config_.decoder_opts),
|
||||
abort_(false), error_(false) {
|
||||
// if the user supplies an adaptation state that was not freshly initialized,
|
||||
// it means that we take the adaptation state from the previous
|
||||
// utterance(s)... this only makes sense if theose previous utterance(s) are
|
||||
// believed to be from the same speaker.
|
||||
feature_pipeline_.SetAdaptationState(adaptation_state);
|
||||
// spawn threads.
|
||||
|
||||
pthread_attr_t pthread_attr;
|
||||
pthread_attr_init(&pthread_attr);
|
||||
int32 ret;
|
||||
|
||||
// Note: if the constructor throws an exception, the corresponding destructor
|
||||
// will not be called. So we don't have to be careful about setting the
|
||||
// thread pointers to NULL after we've joined them.
|
||||
if ((ret=pthread_create(&(threads_[0]),
|
||||
&pthread_attr, RunFeatureExtraction,
|
||||
(void*)this)) != 0) {
|
||||
const char *c = strerror(ret);
|
||||
if (c == NULL) { c = "[NULL]"; }
|
||||
KALDI_ERR << "Error creating thread, errno was: " << c;
|
||||
}
|
||||
if ((ret=pthread_create(&(threads_[1]),
|
||||
&pthread_attr, RunNnetEvaluation,
|
||||
(void*)this)) != 0) {
|
||||
const char *c = strerror(ret);
|
||||
if (c == NULL) { c = "[NULL]"; }
|
||||
bool error = true;
|
||||
AbortAllThreads(error);
|
||||
KALDI_WARN << "Error creating thread, errno was: " << c
|
||||
<< " (will rejoin already-created threads).";
|
||||
if (pthread_join(threads_[0], NULL)) {
|
||||
KALDI_ERR << "Error rejoining thread.";
|
||||
} else {
|
||||
KALDI_ERR << "Error creating thread, errno was: " << c;
|
||||
}
|
||||
}
|
||||
if ((ret=pthread_create(&(threads_[2]),
|
||||
&pthread_attr, RunDecoderSearch,
|
||||
(void*)this)) != 0) {
|
||||
const char *c = strerror(ret);
|
||||
if (c == NULL) { c = "[NULL]"; }
|
||||
bool error = true;
|
||||
AbortAllThreads(error);
|
||||
KALDI_WARN << "Error creating thread, errno was: " << c
|
||||
<< " (will rejoin already-created threads).";
|
||||
if (pthread_join(threads_[0], NULL) || pthread_join(threads_[1], NULL)) {
|
||||
KALDI_ERR << "Error rejoining thread.";
|
||||
} else {
|
||||
KALDI_ERR << "Error creating thread, errno was: " << c;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
SingleUtteranceNnet2DecoderThreaded::~SingleUtteranceNnet2DecoderThreaded() {
|
||||
if (!abort_) {
|
||||
// If we have not already started the process of aborting the threads, do so now.
|
||||
bool error = false;
|
||||
AbortAllThreads(error);
|
||||
}
|
||||
// join all the threads (this avoids leaving zombie threads around, or threads
|
||||
// that might be accessing deconstructed object).
|
||||
WaitForAllThreads();
|
||||
DeletePointers(&input_waveform_);
|
||||
DeletePointers(&feature_buffer_);
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::AcceptWaveform(
|
||||
BaseFloat sampling_rate,
|
||||
const VectorBase<BaseFloat> &wave_part) {
|
||||
if (sampling_rate_ <= 0.0)
|
||||
sampling_rate_ = sampling_rate;
|
||||
else {
|
||||
KALDI_ASSERT(sampling_rate == sampling_rate_);
|
||||
}
|
||||
num_samples_received_ += wave_part.Dim();
|
||||
|
||||
if (wave_part.Dim() == 0) return;
|
||||
if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kProducer)) {
|
||||
KALDI_ERR << "Failure locking mutex: decoding aborted.";
|
||||
}
|
||||
|
||||
Vector<BaseFloat> *new_part = new Vector<BaseFloat>(wave_part);
|
||||
input_waveform_.push_back(new_part);
|
||||
// we always unlock with success because there is no buffer size limitation
|
||||
// for the waveform so no reason why we might wait.
|
||||
waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer);
|
||||
}
|
||||
|
||||
int32 SingleUtteranceNnet2DecoderThreaded::NumFramesReceivedApprox() const {
|
||||
return num_samples_received_ /
|
||||
(sampling_rate_ * feature_pipeline_.FrameShiftInSeconds());
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::InputFinished() {
|
||||
// setting input_finished_ = true informs the feature-processing pipeline
|
||||
// to expect no more input, and to flush out the last few frames if there
|
||||
// is any latency in the pipeline (e.g. due to pitch).
|
||||
if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kProducer)) {
|
||||
KALDI_ERR << "Failure locking mutex: decoding aborted.";
|
||||
}
|
||||
KALDI_ASSERT(!input_finished_ && "InputFinished called twice");
|
||||
input_finished_ = true;
|
||||
waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer);
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::TerminateDecoding() {
|
||||
bool error = false;
|
||||
AbortAllThreads(error);
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::Wait() {
|
||||
if (!input_finished_ && !abort_) {
|
||||
KALDI_ERR << "You cannot call Wait() before calling either InputFinished() "
|
||||
<< "or TerminateDecoding().";
|
||||
}
|
||||
WaitForAllThreads();
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::FinalizeDecoding() {
|
||||
if (KALDI_PTHREAD_PTR(threads_[0]) != 0) {
|
||||
KALDI_ERR << "It is an error to call FinalizeDecoding before Wait().";
|
||||
}
|
||||
decoder_.FinalizeDecoding();
|
||||
}
|
||||
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::GetAdaptationState(
|
||||
OnlineIvectorExtractorAdaptationState *adaptation_state) {
|
||||
feature_pipeline_mutex_.Lock(); // If this blocks, it shouldn't be for very long.
|
||||
feature_pipeline_.GetAdaptationState(adaptation_state);
|
||||
feature_pipeline_mutex_.Unlock(); // If this blocks, it won't be for very long.
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::GetLattice(
|
||||
bool end_of_utterance,
|
||||
CompactLattice *clat,
|
||||
BaseFloat *final_relative_cost) const {
|
||||
clat->DeleteStates();
|
||||
// we'll make an exception to the normal const rules, for mutexes, since
|
||||
// we're not really changing the class.
|
||||
const_cast<Mutex&>(decoder_mutex_).Lock();
|
||||
if (final_relative_cost != NULL)
|
||||
*final_relative_cost = decoder_.FinalRelativeCost();
|
||||
if (decoder_.NumFramesDecoded() == 0) {
|
||||
const_cast<Mutex&>(decoder_mutex_).Unlock();
|
||||
clat->SetFinal(clat->AddState(),
|
||||
CompactLatticeWeight::One());
|
||||
return;
|
||||
}
|
||||
Lattice raw_lat;
|
||||
decoder_.GetRawLattice(&raw_lat, end_of_utterance);
|
||||
const_cast<Mutex&>(decoder_mutex_).Unlock();
|
||||
|
||||
if (!config_.decoder_opts.determinize_lattice)
|
||||
KALDI_ERR << "--determinize-lattice=false option is not supported at the moment";
|
||||
|
||||
BaseFloat lat_beam = config_.decoder_opts.lattice_beam;
|
||||
DeterminizeLatticePhonePrunedWrapper(
|
||||
tmodel_, &raw_lat, lat_beam, clat, config_.decoder_opts.det_opts);
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::GetBestPath(
|
||||
bool end_of_utterance,
|
||||
Lattice *best_path,
|
||||
BaseFloat *final_relative_cost) const {
|
||||
// we'll make an exception to the normal const rules, for mutexes, since
|
||||
// we're not really changing the class.
|
||||
const_cast<Mutex&>(decoder_mutex_).Lock();
|
||||
if (decoder_.NumFramesDecoded() == 0) {
|
||||
// It's possible that this if-statement is not necessary because we'd get this
|
||||
// anyway if we just called GetBestPath on the decoder.
|
||||
best_path->DeleteStates();
|
||||
best_path->SetFinal(best_path->AddState(),
|
||||
LatticeWeight::One());
|
||||
if (final_relative_cost != NULL)
|
||||
*final_relative_cost = std::numeric_limits<BaseFloat>::infinity();
|
||||
} else {
|
||||
decoder_.GetBestPath(best_path,
|
||||
end_of_utterance);
|
||||
if (final_relative_cost != NULL)
|
||||
*final_relative_cost = decoder_.FinalRelativeCost();
|
||||
}
|
||||
const_cast<Mutex&>(decoder_mutex_).Unlock();
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::AbortAllThreads(bool error) {
|
||||
abort_ = true;
|
||||
if (error)
|
||||
error_ = true;
|
||||
waveform_synchronizer_.SetAbort();
|
||||
feature_synchronizer_.SetAbort();
|
||||
decodable_synchronizer_.SetAbort();
|
||||
}
|
||||
|
||||
int32 SingleUtteranceNnet2DecoderThreaded::NumFramesDecoded() const {
|
||||
const_cast<Mutex&>(decoder_mutex_).Lock();
|
||||
int32 ans = decoder_.NumFramesDecoded();
|
||||
const_cast<Mutex&>(decoder_mutex_).Unlock();
|
||||
return ans;
|
||||
}
|
||||
|
||||
void* SingleUtteranceNnet2DecoderThreaded::RunFeatureExtraction(void *ptr_in) {
|
||||
SingleUtteranceNnet2DecoderThreaded *me =
|
||||
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
|
||||
try {
|
||||
if (!me->RunFeatureExtractionInternal() && !me->abort_)
|
||||
KALDI_ERR << "Returned abnormally and abort was not called";
|
||||
} catch(const std::exception &e) {
|
||||
KALDI_WARN << "Caught exception: " << e.what();
|
||||
// if an error happened in one thread, we need to make sure the other
|
||||
// threads can exit too.
|
||||
bool error = true;
|
||||
me->AbortAllThreads(error);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void* SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluation(void *ptr_in) {
|
||||
SingleUtteranceNnet2DecoderThreaded *me =
|
||||
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
|
||||
try {
|
||||
if (!me->RunNnetEvaluationInternal() && !me->abort_)
|
||||
KALDI_ERR << "Returned abnormally and abort was not called";
|
||||
} catch(const std::exception &e) {
|
||||
KALDI_WARN << "Caught exception: " << e.what();
|
||||
// if an error happened in one thread, we need to make sure the other
|
||||
// threads can exit too.
|
||||
bool error = true;
|
||||
me->AbortAllThreads(error);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void* SingleUtteranceNnet2DecoderThreaded::RunDecoderSearch(void *ptr_in) {
|
||||
SingleUtteranceNnet2DecoderThreaded *me =
|
||||
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
|
||||
try {
|
||||
if (!me->RunDecoderSearchInternal() && !me->abort_)
|
||||
KALDI_ERR << "Returned abnormally and abort was not called";
|
||||
} catch(const std::exception &e) {
|
||||
KALDI_WARN << "Caught exception: " << e.what();
|
||||
// if an error happened in one thread, we need to make sure the other threads can exit too.
|
||||
bool error = true;
|
||||
me->AbortAllThreads(error);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::WaitForAllThreads() {
|
||||
for (int32 i = 0; i < 3; i++) { // there are 3 spawned threads.
|
||||
pthread_t &thread = threads_[i];
|
||||
if (KALDI_PTHREAD_PTR(thread) != 0) {
|
||||
if (pthread_join(thread, NULL)) {
|
||||
KALDI_ERR << "Error rejoining thread"; // this should not happen.
|
||||
}
|
||||
KALDI_PTHREAD_PTR(thread) = 0;
|
||||
}
|
||||
}
|
||||
if (error_) {
|
||||
KALDI_ERR << "Error encountered during decoding. See above.";
|
||||
}
|
||||
}
|
||||
|
||||
bool SingleUtteranceNnet2DecoderThreaded::RunFeatureExtractionInternal() {
|
||||
// Note: if any of the functions Lock, UnlockSuccess, UnlockFailure return
|
||||
// false, it is because AbortAllThreads() called, and we return false
|
||||
// immediately.
|
||||
|
||||
// num_frames_output is a local variable that keeps track of how many
|
||||
// frames we have output to the feature buffer, for this utterance.
|
||||
int32 num_frames_output = 0;
|
||||
|
||||
while (true) {
|
||||
// First deal with accepting input.
|
||||
if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
if (input_waveform_.empty()) {
|
||||
if (input_finished_ &&
|
||||
!feature_pipeline_.IsLastFrame(feature_pipeline_.NumFramesReady()-1)) {
|
||||
// the main thread called InputFinished() and set input_finished_, and
|
||||
// we haven't yet registered that fact. This is progress so
|
||||
// UnlockSuccess().
|
||||
feature_pipeline_.InputFinished();
|
||||
if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
} else {
|
||||
// there was no input to process. However, we only call UnlockFailure() if we
|
||||
// are blocked on the fact that there is no input to process; otherwise we
|
||||
// call UnlockSuccess().
|
||||
if (num_frames_output == feature_pipeline_.NumFramesReady()) {
|
||||
// we need to wait until there is more input.
|
||||
if (!waveform_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
} else { // we can keep looping.
|
||||
if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else { // there is more wav data.
|
||||
{ // braces clarify scope of locking.
|
||||
feature_pipeline_mutex_.Lock();
|
||||
for (size_t i = 0; i < input_waveform_.size(); i++)
|
||||
if (input_waveform_[i]->Dim() != 0)
|
||||
feature_pipeline_.AcceptWaveform(sampling_rate_, *input_waveform_[i]);
|
||||
feature_pipeline_mutex_.Unlock();
|
||||
}
|
||||
DeletePointers(&input_waveform_);
|
||||
input_waveform_.clear();
|
||||
if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!feature_synchronizer_.Lock(ThreadSynchronizer::kProducer)) return false;
|
||||
|
||||
if (feature_buffer_.size() >= config_.max_buffered_features) {
|
||||
// we need to block on the output buffer.
|
||||
if (!feature_synchronizer_.UnlockFailure(ThreadSynchronizer::kProducer))
|
||||
return false;
|
||||
} else {
|
||||
|
||||
{ // braces clarify scope of locking.
|
||||
feature_pipeline_mutex_.Lock();
|
||||
// There is buffer space available; deal with producing output.
|
||||
int32 cur_size = feature_buffer_.size(),
|
||||
batch_size = config_.feature_batch_size,
|
||||
feat_dim = feature_pipeline_.Dim();
|
||||
|
||||
for (int32 t = feature_buffer_start_frame_ +
|
||||
static_cast<int32>(feature_buffer_.size());
|
||||
t < feature_buffer_start_frame_ + config_.max_buffered_features &&
|
||||
t < feature_buffer_start_frame_ + cur_size + batch_size &&
|
||||
t < feature_pipeline_.NumFramesReady(); t++) {
|
||||
Vector<BaseFloat> *feats = new Vector<BaseFloat>(feat_dim, kUndefined);
|
||||
// Note: most of the actual computation occurs.
|
||||
feature_pipeline_.GetFrame(t, feats);
|
||||
feature_buffer_.push_back(feats);
|
||||
}
|
||||
num_frames_output = feature_buffer_start_frame_ + feature_buffer_.size();
|
||||
if (feature_pipeline_.IsLastFrame(num_frames_output - 1)) {
|
||||
// e.g. user called InputFinished() and we already saw the last frame.
|
||||
feature_buffer_finished_ = true;
|
||||
}
|
||||
feature_pipeline_mutex_.Unlock();
|
||||
}
|
||||
if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer))
|
||||
return false;
|
||||
|
||||
if (feature_buffer_finished_) return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal() {
|
||||
// if any of the Lock/Unlock functions return false, it's because AbortAllThreads()
|
||||
// was called.
|
||||
|
||||
// This object is responsible for keeping track of the context, and avoiding
|
||||
// re-computing things we've already computed.
|
||||
bool pad_input = true;
|
||||
nnet2::NnetOnlineComputer computer(am_nnet_.GetNnet(), pad_input);
|
||||
|
||||
// we declare the following as CuVector just to enable GPU support, but
|
||||
// we expect this code to be run on CPU in the normal case.
|
||||
CuVector<BaseFloat> log_inv_priors(am_nnet_.Priors());
|
||||
log_inv_priors.ApplyFloor(1.0e-20); // should have no effect.
|
||||
log_inv_priors.ApplyLog();
|
||||
log_inv_priors.Scale(-1.0);
|
||||
|
||||
int32 num_frames_output = 0;
|
||||
|
||||
while (true) {
|
||||
bool last_time = false;
|
||||
|
||||
if (!feature_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
|
||||
CuMatrix<BaseFloat> cu_loglikes;
|
||||
|
||||
if (feature_buffer_.empty()) {
|
||||
if (feature_buffer_finished_) {
|
||||
// flush out the last few frames. Note: this is the only place from
|
||||
// which we check feature_buffer_finished_, and we'll exit the loop, so
|
||||
// if we reach here it must be the first time it was true.
|
||||
last_time = true;
|
||||
if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
computer.Flush(&cu_loglikes);
|
||||
} else {
|
||||
// there is nothing to do because there is no input. Next call to Lock
|
||||
// should block till the feature-processing thread does something.
|
||||
if (!feature_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
int32 num_frames_evaluate = std::min<int32>(feature_buffer_.size(),
|
||||
config_.nnet_batch_size),
|
||||
feat_dim = feature_buffer_[0]->Dim();
|
||||
Matrix<BaseFloat> feats(num_frames_evaluate, feat_dim);
|
||||
for (int32 i = 0; i < num_frames_evaluate; i++) {
|
||||
feats.Row(i).CopyFromVec(*(feature_buffer_[i]));
|
||||
delete feature_buffer_[i];
|
||||
}
|
||||
feature_buffer_.erase(feature_buffer_.begin(),
|
||||
feature_buffer_.begin() + num_frames_evaluate);
|
||||
feature_buffer_start_frame_ += num_frames_evaluate;
|
||||
if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
|
||||
CuMatrix<BaseFloat> cu_feats;
|
||||
cu_feats.Swap(&feats); // If we don't have a GPU (and not having a GPU is
|
||||
// the normal expected use-case for this code),
|
||||
// this would be a lightweight operation, swapping
|
||||
// pointers.
|
||||
|
||||
KALDI_VLOG(4) << "Computing chunk of " << cu_feats.NumRows() << " frames "
|
||||
<< "of nnet.";
|
||||
computer.Compute(cu_feats, &cu_loglikes);
|
||||
cu_loglikes.ApplyFloor(1.0e-20);
|
||||
cu_loglikes.ApplyLog();
|
||||
// take the log-posteriors and turn them into pseudo-log-likelihoods by
|
||||
// dividing by the pdf priors; then scale by the acoustic scale.
|
||||
if (cu_loglikes.NumRows() != 0) {
|
||||
cu_loglikes.AddVecToRows(1.0, log_inv_priors);
|
||||
cu_loglikes.Scale(config_.acoustic_scale);
|
||||
}
|
||||
}
|
||||
|
||||
Matrix<BaseFloat> loglikes;
|
||||
loglikes.Swap(&cu_loglikes); // If we don't have a GPU (and not having a
|
||||
// GPU is the normal expected use-case for
|
||||
// this code), this would be a lightweight
|
||||
// operation, swapping pointers.
|
||||
|
||||
|
||||
// OK, at this point we may have some newly created log-likes and we want to
|
||||
// give them to the decoding thread.
|
||||
|
||||
int32 num_loglike_frames = loglikes.NumRows();
|
||||
|
||||
if (loglikes.NumRows() != 0) { // if we need to output some loglikes...
|
||||
while (true) {
|
||||
// we may have to grab and release the decodable mutex
|
||||
// a few times before it's ready to accept the loglikes.
|
||||
if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kProducer))
|
||||
return false;
|
||||
int32 num_frames_decoded = num_frames_decoded_;
|
||||
// we can't have output fewer frames than were decoded.
|
||||
KALDI_ASSERT(num_frames_output >= num_frames_decoded);
|
||||
if (num_frames_output - num_frames_decoded < config_.max_loglikes_copy) {
|
||||
// If we would have to copy fewer than config_.max_loglikes_copy
|
||||
// previously evaluated log-likelihoods inside the decodable object..
|
||||
int32 frames_to_discard = num_frames_decoded_ -
|
||||
decodable_.FirstAvailableFrame();
|
||||
KALDI_ASSERT(frames_to_discard >= 0);
|
||||
num_frames_output += num_loglike_frames;
|
||||
decodable_.AcceptLoglikes(&loglikes, frames_to_discard);
|
||||
if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer))
|
||||
return false;
|
||||
break; // break from the innermost while loop.
|
||||
} else {
|
||||
// we want the next call to Lock to block until the decoder has
|
||||
// processed more frames.
|
||||
if (!decodable_synchronizer_.UnlockFailure(ThreadSynchronizer::kProducer))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_time) {
|
||||
// Inform the decodable object that there will be no more input.
|
||||
if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kProducer))
|
||||
return false;
|
||||
decodable_.InputIsFinished();
|
||||
if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool SingleUtteranceNnet2DecoderThreaded::RunDecoderSearchInternal() {
|
||||
int32 num_frames_decoded = 0; // this is just a copy of decoder_->NumFramesDecoded();
|
||||
decoder_.InitDecoding();
|
||||
while (true) { // decode at most one frame each loop.
|
||||
if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
|
||||
return false; // AbortAllThreads() called.
|
||||
if (decodable_.NumFramesReady() <= num_frames_decoded) {
|
||||
// no frames available to decode.
|
||||
KALDI_ASSERT(decodable_.NumFramesReady() == num_frames_decoded);
|
||||
if (decodable_.IsLastFrame(num_frames_decoded - 1)) {
|
||||
decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer);
|
||||
return true; // exit from this thread; we're done.
|
||||
} else {
|
||||
// we were not able to advance the decoding due to no available
|
||||
// input. The next call will ensure that the next call to
|
||||
// decodable_synchronizer_.Lock() will wait.
|
||||
if (!decodable_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Decode at most config_.decode_batch_size frames (e.g. 1 or 2).
|
||||
decoder_mutex_.Lock();
|
||||
decoder_.AdvanceDecoding(&decodable_, config_.decode_batch_size);
|
||||
num_frames_decoded = decoder_.NumFramesDecoded();
|
||||
decoder_mutex_.Unlock();
|
||||
num_frames_decoded_ = num_frames_decoded;
|
||||
if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SingleUtteranceNnet2DecoderThreaded::EndpointDetected(
|
||||
const OnlineEndpointConfig &config) {
|
||||
decoder_mutex_.Lock();
|
||||
bool ans = kaldi::EndpointDetected(config, tmodel_,
|
||||
feature_pipeline_.FrameShiftInSeconds(),
|
||||
decoder_);
|
||||
decoder_mutex_.Unlock();
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace kaldi
|
||||
|
|
@ -0,0 +1,415 @@
|
|||
// online2/online-nnet2-decoding-threaded.h
|
||||
|
||||
// Copyright 2014-2015 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_
|
||||
#define KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
|
||||
#include "matrix/matrix-lib.h"
|
||||
#include "util/common-utils.h"
|
||||
#include "base/kaldi-error.h"
|
||||
#include "decoder/decodable-matrix.h"
|
||||
#include "nnet2/am-nnet.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
#include "decoder/lattice-faster-online-decoder.h"
|
||||
#include "hmm/transition-model.h"
|
||||
#include "thread/kaldi-mutex.h"
|
||||
#include "thread/kaldi-semaphore.h"
|
||||
|
||||
namespace kaldi {
|
||||
/// @addtogroup onlinedecoding OnlineDecoding
|
||||
/// @{
|
||||
|
||||
|
||||
/**
|
||||
class ThreadSynchronizer acts to guard an arbitrary type of buffer between a
|
||||
producing and a consuming thread (note: it's all symmetric between the two
|
||||
thread types). It has a similar interface to a mutex, except that instead of
|
||||
just Lock and Unlock, it has Lock, UnlockSuccess and UnlockFailure, and each
|
||||
function takes an argument kProducer or kConsumer to identify whether the
|
||||
producing or consuming thread is waiting.
|
||||
|
||||
The basic concept is that you lock the object; and if you discover the you're
|
||||
blocked because you're either trying to read an empty buffer or trying to
|
||||
write to a full buffer, you unlock with UnlockFailure; and this will cause
|
||||
your next call to Lock to block until the *other* thread has called Lock and
|
||||
then UnlockSuccess. However, if at that point the other thread calls Lock
|
||||
and then UnlockFailure, it is an error because you can't have both producing
|
||||
and consuming threads claiming that the buffer is full/empty. If you lock
|
||||
the object and were successful you call UnlockSuccess; and you call
|
||||
UnlockSuccess even if, for your own reasons, you ended up not changing the
|
||||
state of the buffer.
|
||||
*/
|
||||
class ThreadSynchronizer {
|
||||
public:
|
||||
ThreadSynchronizer();
|
||||
|
||||
// Most calls to this class should provide the thread-type of the caller,
|
||||
// producing or consuming. Actually the behavior of this class is symmetric
|
||||
// between the two types of thread.
|
||||
enum ThreadType { kProducer, kConsumer };
|
||||
|
||||
// All functions returning bool will return true normally, and false if
|
||||
// SetAbort() was set; if they return false, you should probably call SetAbort()
|
||||
// on any other ThreadSynchronizer classes you are using and then return from
|
||||
// the thread.
|
||||
|
||||
// call this to lock the object being guarded.
|
||||
bool Lock(ThreadType t);
|
||||
|
||||
// Call this to unlock the object being guarded, if you don't want the next call to
|
||||
// Lock to stall.
|
||||
bool UnlockSuccess(ThreadType t);
|
||||
|
||||
// Call this if you want the next call to Lock() to stall until the other
|
||||
// (producer/consumer) thread has locked and then unlocked the mutex. Note
|
||||
// that, if the other thread then calls Lock and then UnlockFailure, this will
|
||||
// generate a printed warning (and if repeated too many times, an exception).
|
||||
bool UnlockFailure(ThreadType t);
|
||||
|
||||
// Sets abort_ flag so future calls will return false, and future calls to
|
||||
// Lock() won't lock the mutex but will immediately return false.
|
||||
void SetAbort();
|
||||
|
||||
~ThreadSynchronizer();
|
||||
|
||||
private:
|
||||
bool abort_;
|
||||
bool producer_waiting_; // true if producer is/will be waiting on semaphore
|
||||
bool consumer_waiting_; // true if consumer is/will be waiting on semaphore
|
||||
Mutex mutex_; // Locks the buffer object.
|
||||
ThreadType held_by_; // Record of which thread is holding the mutex (if
|
||||
// held); else undefined. Used for validation of input.
|
||||
Semaphore producer_semaphore_; // The producer thread waits on this semaphore
|
||||
Semaphore consumer_semaphore_; // The consumer thread waits on this semaphore
|
||||
int32 num_errors_; // Rumber of times the threads alternated doing Lock() and
|
||||
// UnlockFailure(). This should not happen at all; but
|
||||
// it's more user-friendly to simply warn a few times; and then
|
||||
// only after a number of errors, to fail.
|
||||
KALDI_DISALLOW_COPY_AND_ASSIGN(ThreadSynchronizer);
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
// This is the configuration class for SingleUtteranceNnet2DecoderThreaded. The
|
||||
// actual command line program requires other configs that it creates
|
||||
// separately, and which are not included here: namely,
|
||||
// OnlineNnet2FeaturePipelineConfig and OnlineEndpointConfig.
|
||||
struct OnlineNnet2DecodingThreadedConfig {
|
||||
|
||||
LatticeFasterDecoderConfig decoder_opts;
|
||||
|
||||
BaseFloat acoustic_scale;
|
||||
|
||||
int32 max_buffered_features; // maximum frames of features we allow to be
|
||||
// held in the feature buffer before we block
|
||||
// the feature-processing thread.
|
||||
|
||||
int32 feature_batch_size; // maximum number of frames at a time that we decode
|
||||
// before unlocking the mutex. The only real cost
|
||||
// here is a mutex lock/unlock, so it's OK to make
|
||||
// this fairly small.
|
||||
int32 max_loglikes_copy; // maximum unused frames of log-likelihoods we will
|
||||
// copy from the decodable object back into another
|
||||
// matrix to be supplied to the decodable object.
|
||||
// make this too large-> will block the
|
||||
// decoder-search thread while copying; too small
|
||||
// -> the nnet-evaluation thread may get blocked
|
||||
// for too long while waiting for the decodable
|
||||
// thread to be ready.
|
||||
int32 nnet_batch_size; // batch size (number of frames) we evaluate in the
|
||||
// neural net, if this many is available. To take
|
||||
// best advantage of BLAS, you may want to set this
|
||||
// fairly large, e.g. 32 or 64 frames. It probably
|
||||
// makes sense to tune this a bit.
|
||||
int32 decode_batch_size; // maximum number of frames at a time that we decode
|
||||
// before unlocking the mutex. The only real cost
|
||||
// here is a mutex lock/unlock, so it's OK to make
|
||||
// this fairly small.
|
||||
|
||||
OnlineNnet2DecodingThreadedConfig() {
|
||||
acoustic_scale = 0.1;
|
||||
max_buffered_features = 100;
|
||||
feature_batch_size = 2;
|
||||
nnet_batch_size = 32;
|
||||
max_loglikes_copy = 20;
|
||||
decode_batch_size = 2;
|
||||
}
|
||||
|
||||
void Check();
|
||||
|
||||
void Register(OptionsItf *po) {
|
||||
decoder_opts.Register(po);
|
||||
po->Register("acoustic-scale", &acoustic_scale, "Scale used on acoustics "
|
||||
"when decoding");
|
||||
po->Register("max-buffered-features", &max_buffered_features, "Obscure "
|
||||
"setting, affects multi-threaded decoding.");
|
||||
po->Register("feature-batch-size", &max_buffered_features, "Obscure "
|
||||
"setting, affects multi-threaded decoding.");
|
||||
po->Register("nnet-batch-size", &nnet_batch_size, "Maximum batch size "
|
||||
"(in frames) used when evaluating neural net likelihoods");
|
||||
po->Register("max-loglikes-copy", &max_loglikes_copy, "Obscure "
|
||||
"setting, affects multi-threaded decoding.");
|
||||
po->Register("decode-batch-sie", &decode_batch_size, "Obscure "
|
||||
"setting, affects multi-threaded decoding.");
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
You will instantiate this class when you want to decode a single
|
||||
utterance using the online-decoding setup for neural nets. Each time this
|
||||
class is created, it creates three background threads, and the feature
|
||||
extraction, neural net evaluation, and search aspects of decoding all
|
||||
happen in different threads.
|
||||
Note: we assume that all calls to its public interface happen from a single
|
||||
thread.
|
||||
*/
|
||||
class SingleUtteranceNnet2DecoderThreaded {
|
||||
public:
|
||||
// Constructor. Unlike SingleUtteranceNnet2Decoder, we create the
|
||||
// feature_pipeline object inside this class, since access to it needs to be
|
||||
// controlled by a mutex and this class knows how to handle that. The
|
||||
// feature_info and adaptation_state arguments are used to initialize the
|
||||
// (locally owned) feature pipeline.
|
||||
SingleUtteranceNnet2DecoderThreaded(
|
||||
const OnlineNnet2DecodingThreadedConfig &config,
|
||||
const TransitionModel &tmodel,
|
||||
const nnet2::AmNnet &am_nnet,
|
||||
const fst::Fst<fst::StdArc> &fst,
|
||||
const OnlineNnet2FeaturePipelineInfo &feature_info,
|
||||
const OnlineIvectorExtractorAdaptationState &adaptation_state);
|
||||
|
||||
/// You call this to provide this class with more waveform to decode. This
|
||||
/// call is, for all practical purposes, non-blocking.
|
||||
void AcceptWaveform(BaseFloat samp_freq,
|
||||
const VectorBase<BaseFloat> &wave_part);
|
||||
|
||||
/// You call this to inform the class that no more waveform will be provided;
|
||||
/// this allows it to flush out the last few frames of features, and is
|
||||
/// necessary if you want to call Wait() to wait until all decoding is done.
|
||||
/// After calling InputFinished() you cannot call AcceptWaveform any more.
|
||||
void InputFinished();
|
||||
|
||||
/// You can call this if you don't want the decoding to proceed further with
|
||||
/// this utterance. It just won't do any more processing, but you can still
|
||||
/// use the lattice from the decoding that it's already done. Note: it may
|
||||
/// still continue decoding up to decode_batch_size (default: 2) frames of
|
||||
/// data before the decoding thread exits. You can call Wait() after calling
|
||||
/// this, if you want to wait for that.
|
||||
void TerminateDecoding();
|
||||
|
||||
/// This call will block until all the data has been decoded; it must only be
|
||||
/// called after either InputFinished() has been called or TerminateDecoding() has
|
||||
/// been called; otherwise, to call it is an error.
|
||||
void Wait();
|
||||
|
||||
/// Finalizes the decoding. Cleans up and prunes remaining tokens, so the final
|
||||
/// lattice is faster to obtain. May not be called unless either InputFinished()
|
||||
/// or TerminateDecoding() has been called. If InputFinished() was called, it
|
||||
/// calls Wait() to ensure that the decoding has finished (it's not an error
|
||||
/// if you already called Wait()).
|
||||
void FinalizeDecoding();
|
||||
|
||||
/// Returns *approximately* (ignoring end effects), the number of frames of
|
||||
/// data that we expect given the amount of data that the pipeline has
|
||||
/// received via AcceptWaveform(). (ignores small end effects). This might
|
||||
/// be useful in application code to compare with NumFramesDecoded() and gauge
|
||||
/// how much latency there is.
|
||||
int32 NumFramesReceivedApprox() const;
|
||||
|
||||
/// Returns the number of frames currently decoded. Caution: don't rely on
|
||||
/// the lattice having exactly this number if you get it after this call, as
|
||||
/// it may increase after this-- unless you've already called either
|
||||
/// TerminateDecoding() or InputFinished(), followed by Wait().
|
||||
int32 NumFramesDecoded() const;
|
||||
|
||||
/// Gets the lattice. The output lattice has any acoustic scaling in it
|
||||
/// (which will typically be desirable in an online-decoding context); if you
|
||||
/// want an un-scaled lattice, scale it using ScaleLattice() with the inverse
|
||||
/// of the acoustic weight. "end_of_utterance" will be true if you want the
|
||||
/// final-probs to be included. If this is at the end of the utterance,
|
||||
/// you might want to first call FinalizeDecoding() first; this will make this
|
||||
/// call return faster.
|
||||
/// If no frames have been decoded yet, it will set clat to a lattice with
|
||||
/// a single state that is final and with unit weight (no cost or alignment).
|
||||
/// The output to final_relative_cost (if non-NULL) is a number >= 0 that's
|
||||
/// closer to 0 if a final-state was close to the best-likelihood state
|
||||
/// active on the last frame, at the time we obtained the lattice.
|
||||
void GetLattice(bool end_of_utterance,
|
||||
CompactLattice *clat,
|
||||
BaseFloat *final_relative_cost) const;
|
||||
|
||||
/// Outputs an FST corresponding to the single best path through the current
|
||||
/// lattice. If "use_final_probs" is true AND we reached the final-state of
|
||||
/// the graph then it will include those as final-probs, else it will treat
|
||||
/// all final-probs as one.
|
||||
/// If no frames have been decoded yet, it will set best_path to a lattice with
|
||||
/// a single state that is final and with unit weight (no cost).
|
||||
/// The output to final_relative_cost (if non-NULL) is a number >= 0 that's
|
||||
/// closer to 0 if a final-state were close to the best-likelihood state
|
||||
/// active on the last frame, at the time we got the best path.
|
||||
void GetBestPath(bool end_of_utterance,
|
||||
Lattice *best_path,
|
||||
BaseFloat *final_relative_cost) const;
|
||||
|
||||
/// This function calls EndpointDetected from online-endpoint.h,
|
||||
/// with the required arguments.
|
||||
bool EndpointDetected(const OnlineEndpointConfig &config);
|
||||
|
||||
/// Outputs the adaptation state of the feature pipeline to "adaptation_state". This
|
||||
/// mostly stores stats for iVector estimation, and will generally be called at the
|
||||
/// end of an utterance, assuming it's a scenario where each speaker is seen for
|
||||
/// more than one utterance.
|
||||
/// You may only call this function after either calling TerminateDecoding() or
|
||||
/// InputFinished, and then Wait(). Otherwise it is an error.
|
||||
void GetAdaptationState(OnlineIvectorExtractorAdaptationState *adaptation_state);
|
||||
|
||||
~SingleUtteranceNnet2DecoderThreaded();
|
||||
private:
|
||||
|
||||
// This function will instruct all threads to abort operation as soon as they
|
||||
// can safely do so, by calling SetAbort() in the threads
|
||||
void AbortAllThreads(bool error);
|
||||
|
||||
// This function waits for all the threads that have been spawned, and then
|
||||
// sets the pointers in threads_ to NULL; it is called in the destructor and
|
||||
// from Wait(). If called twice it is not an error.
|
||||
void WaitForAllThreads();
|
||||
|
||||
|
||||
// this function runs the thread that does the feature extraction; ptr_in is
|
||||
// to class SingleUtteranceNnet2DecoderThreaded. Always returns NULL, but in
|
||||
// case of failure, calls ptr_in->AbortAllThreads(true).
|
||||
static void* RunFeatureExtraction(void *ptr_in);
|
||||
// member-function version of RunFeatureExtraction, called by RunFeatureExgtraction.
|
||||
bool RunFeatureExtractionInternal();
|
||||
|
||||
|
||||
// this function runs the thread that does the neural-net evaluation ptr_in is
|
||||
// to class SingleUtteranceNnet2DecoderThreaded. Always returns NULL, but in
|
||||
// case of failure, calls ptr_in->AbortAllThreads(true).
|
||||
static void* RunNnetEvaluation(void *ptr_in);
|
||||
// member-function version of RunNnetEvaluation, called by RunNnetEvaluation.
|
||||
bool RunNnetEvaluationInternal();
|
||||
|
||||
|
||||
// this function runs the thread that does the neural-net evaluation ptr_in is
|
||||
// to class SingleUtteranceNnet2DecoderThreaded. Always returns NULL, but in
|
||||
// case of failure, calls ptr_in->AbortAllThreads(true).
|
||||
static void* RunDecoderSearch(void *ptr_in);
|
||||
// member-function version of RunDecoderSearch, called by RunDecoderSearch.
|
||||
bool RunDecoderSearchInternal();
|
||||
|
||||
|
||||
// Member variables:
|
||||
|
||||
OnlineNnet2DecodingThreadedConfig config_;
|
||||
|
||||
const nnet2::AmNnet &am_nnet_;
|
||||
|
||||
const TransitionModel &tmodel_;
|
||||
|
||||
|
||||
// sampling_rate_ is set the first time AcceptWaveform is called.
|
||||
BaseFloat sampling_rate_;
|
||||
// A record of how many samples have been provided so
|
||||
// far via calls to AcceptWaveform.
|
||||
int64 num_samples_received_;
|
||||
|
||||
// The next two variables are written to by AcceptWaveform from the main
|
||||
// thread, and read by the feature-processing thread; they are guarded by
|
||||
// waveform_synchronizer_. There is no bound on the buffer size here.
|
||||
// Later-arriving data is appended to the vector. When InputFinished() is
|
||||
// called from the main thread, the main thread sets input_finished_ = true.
|
||||
// sampling_rate_ is only needed for checking that it matches the config.
|
||||
bool input_finished_;
|
||||
std::vector< Vector<BaseFloat>* > input_waveform_;
|
||||
ThreadSynchronizer waveform_synchronizer_;
|
||||
|
||||
// feature_pipeline_ is only accessed by the feature-processing thread, and by
|
||||
// the main thread if GetAdaptionState() is called. It is guarded by feature_pipeline_mutex_.
|
||||
OnlineNnet2FeaturePipeline feature_pipeline_;
|
||||
Mutex feature_pipeline_mutex_;
|
||||
|
||||
// The following three variables are guarded by feature_synchronizer_; they are
|
||||
// a buffer of features that is passed between the feature-processing thread
|
||||
// and the neural-net evaluation thread. feature_buffer_start_frame_ is the
|
||||
// frame index of the first feature vector in the buffer. After the
|
||||
// neural-net evaluation thread reads the features, it removes them from the
|
||||
// vector "feature_buffer_" and advances "feature_buffer_start_frame_"
|
||||
// appropriately. The feature-processing thread does not advance
|
||||
// feature_buffer_start_frame_, but appends to feature_buffer_. After
|
||||
// all features are done (because user called InputFinished()), the
|
||||
// feature-processing thread sets feature_buffer_finished_.
|
||||
int32 feature_buffer_start_frame_;
|
||||
bool feature_buffer_finished_;
|
||||
std::vector<Vector<BaseFloat>* > feature_buffer_;
|
||||
ThreadSynchronizer feature_synchronizer_;
|
||||
|
||||
// this Decodable object just stores a matrix of scaled log-likelihoods
|
||||
// obtained by the nnet-evaluation thread. It is produced by the
|
||||
// nnet-evaluation thread and consumed by the decoder-search thread. The
|
||||
// decoding thread sets num_frames_decoded_ so the nnet-evaluation thread
|
||||
// knows which frames of log-likelihoods it can discard. Both of these
|
||||
// variables are guarded by decodable_synchronizer_. Note:
|
||||
// the num_frames_decoded_ may be less than the current number of frames
|
||||
// the decoder has decoded; the decoder thread sets this variable when it
|
||||
// locks this mutex.
|
||||
DecodableMatrixMappedOffset decodable_;
|
||||
int32 num_frames_decoded_;
|
||||
ThreadSynchronizer decodable_synchronizer_;
|
||||
|
||||
// the decoder_ object contains everything related to the graph search.
|
||||
LatticeFasterOnlineDecoder decoder_;
|
||||
// decoder_mutex_ guards the decoder_ object. It is usually held by the decoding
|
||||
// thread (where it is released and re-obtained on each frame), but is obtained
|
||||
// by the main (parent) thread if you call functions like NumFramesDecoded(),
|
||||
// GetLattice() and GetBestPath().
|
||||
Mutex decoder_mutex_;
|
||||
|
||||
// This contains the thread pointers for the feature-extraction,
|
||||
// nnet-evaluation, and decoder-search threads respectively (or NULL if they
|
||||
// have been joined in Wait()).
|
||||
pthread_t threads_[3];
|
||||
|
||||
// This is set to true if AbortAllThreads was called for any reason, including
|
||||
// if someone called TerminateDecoding().
|
||||
bool abort_;
|
||||
|
||||
// This is set to true if any kind of unexpected error is encountered,
|
||||
// including if exceptions are raised in any of the threads. Will normally
|
||||
// be a coding error, malloc failure-- something we should never encounter.
|
||||
bool error_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
/// @} End of "addtogroup onlinedecoding"
|
||||
|
||||
} // namespace kaldi
|
||||
|
||||
|
||||
|
||||
#endif // KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_
|
|
@ -33,7 +33,7 @@ SingleUtteranceNnet2Decoder::SingleUtteranceNnet2Decoder(
|
|||
feature_pipeline_(feature_pipeline),
|
||||
tmodel_(tmodel),
|
||||
decodable_(model, tmodel, config.decodable_opts, feature_pipeline),
|
||||
decoder_(fst, config.faster_decoder_opts) {
|
||||
decoder_(fst, config.decoder_opts) {
|
||||
decoder_.InitDecoding();
|
||||
}
|
||||
|
||||
|
@ -56,12 +56,12 @@ void SingleUtteranceNnet2Decoder::GetLattice(bool end_of_utterance,
|
|||
Lattice raw_lat;
|
||||
decoder_.GetRawLattice(&raw_lat, end_of_utterance);
|
||||
|
||||
if (!config_.faster_decoder_opts.determinize_lattice)
|
||||
if (!config_.decoder_opts.determinize_lattice)
|
||||
KALDI_ERR << "--determinize-lattice=false option is not supported at the moment";
|
||||
|
||||
BaseFloat lat_beam = config_.faster_decoder_opts.lattice_beam;
|
||||
BaseFloat lat_beam = config_.decoder_opts.lattice_beam;
|
||||
DeterminizeLatticePhonePrunedWrapper(
|
||||
tmodel_, &raw_lat, lat_beam, clat, config_.faster_decoder_opts.det_opts);
|
||||
tmodel_, &raw_lat, lat_beam, clat, config_.decoder_opts.det_opts);
|
||||
}
|
||||
|
||||
void SingleUtteranceNnet2Decoder::GetBestPath(bool end_of_utterance,
|
||||
|
|
|
@ -49,13 +49,13 @@ namespace kaldi {
|
|||
// here: namely, OnlineNnet2FeaturePipelineConfig and OnlineEndpointConfig.
|
||||
struct OnlineNnet2DecodingConfig {
|
||||
|
||||
LatticeFasterDecoderConfig faster_decoder_opts;
|
||||
LatticeFasterDecoderConfig decoder_opts;
|
||||
nnet2::DecodableNnet2OnlineOptions decodable_opts;
|
||||
|
||||
OnlineNnet2DecodingConfig() { decodable_opts.acoustic_scale = 0.1; }
|
||||
|
||||
void Register(OptionsItf *po) {
|
||||
faster_decoder_opts.Register(po);
|
||||
decoder_opts.Register(po);
|
||||
decodable_opts.Register(po);
|
||||
}
|
||||
};
|
||||
|
@ -77,8 +77,9 @@ class SingleUtteranceNnet2Decoder {
|
|||
/// advance the decoding as far as we can.
|
||||
void AdvanceDecoding();
|
||||
|
||||
/// Finalizes the decoding. Cleanups and prunes remaining tokens, so the final
|
||||
/// result is faster to obtain.
|
||||
/// Finalizes the decoding. Cleans up and prunes remaining tokens, so the
|
||||
/// GetLattice() call will return faster. You must not call this before
|
||||
/// calling (TerminateDecoding() or InputIsFinished()) and then Wait().
|
||||
void FinalizeDecoding();
|
||||
|
||||
int32 NumFramesDecoded() const;
|
||||
|
@ -98,13 +99,6 @@ class SingleUtteranceNnet2Decoder {
|
|||
void GetBestPath(bool end_of_utterance,
|
||||
Lattice *best_path) const;
|
||||
|
||||
/// This function outputs to "final_relative_cost", if non-NULL, a number >= 0
|
||||
/// that will be close to zero if the final-probs were close to the best probs
|
||||
/// active on the final frame. (the output to final_relative_cost is based on
|
||||
/// the first-pass decoding). If it's close to zero (e.g. < 5, as a guess),
|
||||
/// it means you reached the end of the grammar with good probability, which
|
||||
/// can be taken as a good sign that the input was OK.
|
||||
BaseFloat FinalRelativeCost() { return decoder_.FinalRelativeCost(); }
|
||||
|
||||
/// This function calls EndpointDetected from online-endpoint.h,
|
||||
/// with the required arguments.
|
||||
|
|
|
@ -184,4 +184,5 @@ BaseFloat OnlineNnet2FeaturePipelineInfo::FrameShiftInSeconds() const {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
} // namespace kaldi
|
||||
|
|
|
@ -203,9 +203,7 @@ class OnlineNnet2FeaturePipeline: public OnlineFeatureInterface {
|
|||
void AcceptWaveform(BaseFloat sampling_rate,
|
||||
const VectorBase<BaseFloat> &waveform);
|
||||
|
||||
BaseFloat FrameShiftInSeconds() const {
|
||||
return info_.FrameShiftInSeconds();
|
||||
}
|
||||
BaseFloat FrameShiftInSeconds() const { return info_.FrameShiftInSeconds(); }
|
||||
|
||||
/// If you call InputFinished(), it tells the class you won't be providing any
|
||||
/// more waveform. This will help flush out the last few frames of delta or
|
||||
|
|
|
@ -36,7 +36,11 @@ void OnlineTimingStats::Print(bool online){
|
|||
KALDI_LOG << "Timing stats: real-time factor was " << real_time_factor
|
||||
<< " (note: this cannot be less than one.)";
|
||||
KALDI_LOG << "Average delay was " << average_wait << " seconds.";
|
||||
if (idle_percent != 0.0) {
|
||||
// If the user was calling SleepUntil instead of WaitUntil, this will
|
||||
// always be zero; so don't print it in that case.
|
||||
KALDI_LOG << "Percentage of time spent idling was " << idle_percent;
|
||||
}
|
||||
KALDI_LOG << "Longest delay was " << max_delay_ << " seconds for utterance "
|
||||
<< '\'' << max_delay_utt_ << '\'';
|
||||
} else {
|
||||
|
@ -74,6 +78,17 @@ void OnlineTimer::WaitUntil(double cur_utterance_length) {
|
|||
utterance_length_ = cur_utterance_length;
|
||||
}
|
||||
|
||||
void OnlineTimer::SleepUntil(double cur_utterance_length) {
|
||||
KALDI_ASSERT(waited_ == 0 && "Do not mix SleepUntil with WaitUntil.");
|
||||
double elapsed = timer_.Elapsed();
|
||||
|
||||
double to_wait = cur_utterance_length - elapsed;
|
||||
if (to_wait > 0.0) {
|
||||
Sleep(to_wait);
|
||||
}
|
||||
utterance_length_ = cur_utterance_length;
|
||||
}
|
||||
|
||||
double OnlineTimer::Elapsed() {
|
||||
return timer_.Elapsed() + waited_;
|
||||
}
|
||||
|
@ -88,6 +103,9 @@ void OnlineTimer::OutputStats(OnlineTimingStats *stats) {
|
|||
KALDI_WARN << "Negative wait time " << wait_time
|
||||
<< " does not make sense.";
|
||||
}
|
||||
KALDI_VLOG(2) << "Latency " << wait_time << " seconds out of "
|
||||
<< utterance_length_ << ", for utterance "
|
||||
<< utterance_id_;
|
||||
|
||||
stats->num_utts_++;
|
||||
stats->total_audio_ += utterance_length_;
|
||||
|
|
|
@ -50,8 +50,10 @@ class OnlineTimingStats {
|
|||
int32 num_utts_;
|
||||
// all times are in seconds.
|
||||
double total_audio_; // total time of audio.
|
||||
double total_time_taken_;
|
||||
double total_time_waited_; // total time in wait() state.
|
||||
double total_time_taken_; // total time spent processing the audio.
|
||||
double total_time_waited_; // total time we pretended to wait (but just
|
||||
// increased the waited_ counter)... zero if you
|
||||
// called SleepUntil instead of WaitUntil().
|
||||
double max_delay_; // maximum delay at utterance end.
|
||||
std::string max_delay_utt_;
|
||||
};
|
||||
|
@ -65,7 +67,8 @@ class OnlineTimingStats {
|
|||
/// available in a real-time application-- e.g. say we need to process a chunk
|
||||
/// that ends half a second into the utterance, we would sleep until half a
|
||||
/// second had elapsed since the start of the utterance. In this code we
|
||||
/// don't actually sleep; we simulate the effect of sleeping by just incrementing
|
||||
/// have the option to not actually sleep: we can simulate the effect of
|
||||
/// sleeping by just incrementing
|
||||
/// a variable that says how long we would have slept; and we add this to
|
||||
/// wall-clock times obtained from Timer::Elapsed().
|
||||
/// The usage of this class will be something like as follows:
|
||||
|
@ -86,8 +89,14 @@ class OnlineTimer {
|
|||
public:
|
||||
OnlineTimer(const std::string &utterance_id);
|
||||
|
||||
/// The call to WaitUntil(t) simulates the effect of waiting
|
||||
/// until t seconds after this object was initialized.
|
||||
/// The call to SleepUntil(t) will sleep until cur_utterance_length seconds
|
||||
/// after this object was initialized, or return immediately if we've
|
||||
/// already passed that time.
|
||||
void SleepUntil(double cur_utterance_length);
|
||||
|
||||
/// The call to WaitUntil(t) simulates the effect of sleeping until
|
||||
/// cur_utterance_length seconds after this object was initialized;
|
||||
/// but instead of actually sleeping, it increases a counter.
|
||||
void WaitUntil(double cur_utterance_length);
|
||||
|
||||
/// This call, which should be made after decoding is done,
|
||||
|
|
|
@ -10,7 +10,7 @@ BINFILES = online2-wav-gmm-latgen-faster apply-cmvn-online \
|
|||
extend-wav-with-silence compress-uncompress-speex \
|
||||
online2-wav-nnet2-latgen-faster ivector-extract-online2 \
|
||||
online2-wav-dump-features ivector-randomize \
|
||||
online2-wav-nnet2-am-compute
|
||||
online2-wav-nnet2-am-compute online2-wav-nnet2-latgen-threaded
|
||||
|
||||
OBJFILES =
|
||||
|
||||
|
|
|
@ -197,7 +197,9 @@ void FindQuietestSegment(const Vector<BaseFloat> &wav_in,
|
|||
start += seg_shift;
|
||||
}
|
||||
|
||||
KALDI_ASSERT(min_energy > 0.0);
|
||||
if (min_energy == 0.0) {
|
||||
KALDI_WARN << "Zero energy silence being used.";
|
||||
}
|
||||
*wav_sil = wav_min_energy;
|
||||
}
|
||||
|
||||
|
|
|
@ -88,7 +88,8 @@ int main(int argc, char *argv[]) {
|
|||
"<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
|
||||
"The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
|
||||
"you want to decode utterance by utterance.\n"
|
||||
"See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n";
|
||||
"See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n"
|
||||
"See also online2-wav-nnet2-latgen-threaded\n";
|
||||
|
||||
ParseOptions po(usage);
|
||||
|
||||
|
|
|
@ -0,0 +1,277 @@
|
|||
// onlinebin/online2-wav-nnet2-latgen-thread.cc
|
||||
|
||||
// Copyright 2014-2015 Johns Hopkins University (author: Daniel Povey)
|
||||
|
||||
// See ../../COPYING for clarification regarding multiple authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "feat/wave-reader.h"
|
||||
#include "online2/online-nnet2-decoding-threaded.h"
|
||||
#include "online2/onlinebin-util.h"
|
||||
#include "online2/online-timing.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
#include "fstext/fstext-lib.h"
|
||||
#include "lat/lattice-functions.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
void GetDiagnosticsAndPrintOutput(const std::string &utt,
|
||||
const fst::SymbolTable *word_syms,
|
||||
const CompactLattice &clat,
|
||||
int64 *tot_num_frames,
|
||||
double *tot_like) {
|
||||
if (clat.NumStates() == 0) {
|
||||
KALDI_WARN << "Empty lattice.";
|
||||
return;
|
||||
}
|
||||
CompactLattice best_path_clat;
|
||||
CompactLatticeShortestPath(clat, &best_path_clat);
|
||||
|
||||
Lattice best_path_lat;
|
||||
ConvertLattice(best_path_clat, &best_path_lat);
|
||||
|
||||
double likelihood;
|
||||
LatticeWeight weight;
|
||||
int32 num_frames;
|
||||
std::vector<int32> alignment;
|
||||
std::vector<int32> words;
|
||||
GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
|
||||
num_frames = alignment.size();
|
||||
likelihood = -(weight.Value1() + weight.Value2());
|
||||
*tot_num_frames += num_frames;
|
||||
*tot_like += likelihood;
|
||||
KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is "
|
||||
<< (likelihood / num_frames) << " over " << num_frames
|
||||
<< " frames.";
|
||||
|
||||
if (word_syms != NULL) {
|
||||
std::cerr << utt << ' ';
|
||||
for (size_t i = 0; i < words.size(); i++) {
|
||||
std::string s = word_syms->Find(words[i]);
|
||||
if (s == "")
|
||||
KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
|
||||
std::cerr << s << ' ';
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
try {
|
||||
using namespace kaldi;
|
||||
using namespace fst;
|
||||
|
||||
typedef kaldi::int32 int32;
|
||||
typedef kaldi::int64 int64;
|
||||
|
||||
const char *usage =
|
||||
"Reads in wav file(s) and simulates online decoding with neural nets\n"
|
||||
"(nnet2 setup), with optional iVector-based speaker adaptation and\n"
|
||||
"optional endpointing. This version uses multiple threads for decoding.\n"
|
||||
"Note: some configuration values and inputs are set via config files\n"
|
||||
"whose filenames are passed as options\n"
|
||||
"\n"
|
||||
"Usage: online2-wav-nnet2-latgen-threaded [options] <nnet2-in> <fst-in> "
|
||||
"<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
|
||||
"The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
|
||||
"you want to decode utterance by utterance.\n"
|
||||
"See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n"
|
||||
"See also online2-wav-nnet2-latgen-faster\n";
|
||||
|
||||
ParseOptions po(usage);
|
||||
|
||||
std::string word_syms_rxfilename;
|
||||
|
||||
OnlineEndpointConfig endpoint_config;
|
||||
|
||||
// feature_config includes configuration for the iVector adaptation,
|
||||
// as well as the basic features.
|
||||
OnlineNnet2FeaturePipelineConfig feature_config;
|
||||
OnlineNnet2DecodingThreadedConfig nnet2_decoding_config;
|
||||
|
||||
BaseFloat chunk_length_secs = 0.05;
|
||||
bool do_endpointing = false;
|
||||
bool modify_ivector_config = false;
|
||||
|
||||
po.Register("chunk-length", &chunk_length_secs,
|
||||
"Length of chunk size in seconds, that we provide each time to the "
|
||||
"decoder. The actual chunk sizes it processes for various stages "
|
||||
"of decoding are dynamically determinated, and unrelated to this");
|
||||
po.Register("word-symbol-table", &word_syms_rxfilename,
|
||||
"Symbol table for words [for debug output]");
|
||||
po.Register("do-endpointing", &do_endpointing,
|
||||
"If true, apply endpoint detection");
|
||||
po.Register("modify-ivector-config", &modify_ivector_config,
|
||||
"If true, modifies the iVector configuration from the config files "
|
||||
"by setting --use-most-recent-ivector=true and --greedy-ivector-extractor=true. "
|
||||
"This will give the best possible results, but the results may become dependent "
|
||||
"on the speed of your machine (slower machine -> better results). Compare "
|
||||
"to the --online option in online2-wav-nnet-latgen-faster");
|
||||
|
||||
feature_config.Register(&po);
|
||||
nnet2_decoding_config.Register(&po);
|
||||
endpoint_config.Register(&po);
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() != 5) {
|
||||
po.PrintUsage();
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string nnet2_rxfilename = po.GetArg(1),
|
||||
fst_rxfilename = po.GetArg(2),
|
||||
spk2utt_rspecifier = po.GetArg(3),
|
||||
wav_rspecifier = po.GetArg(4),
|
||||
clat_wspecifier = po.GetArg(5);
|
||||
|
||||
OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
|
||||
|
||||
if (modify_ivector_config) {
|
||||
feature_info.ivector_extractor_info.use_most_recent_ivector = true;
|
||||
feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
|
||||
}
|
||||
|
||||
TransitionModel trans_model;
|
||||
nnet2::AmNnet am_nnet;
|
||||
{
|
||||
bool binary;
|
||||
Input ki(nnet2_rxfilename, &binary);
|
||||
trans_model.Read(ki.Stream(), binary);
|
||||
am_nnet.Read(ki.Stream(), binary);
|
||||
}
|
||||
|
||||
fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldi(fst_rxfilename);
|
||||
|
||||
fst::SymbolTable *word_syms = NULL;
|
||||
if (word_syms_rxfilename != "")
|
||||
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
|
||||
KALDI_ERR << "Could not read symbol table from file "
|
||||
<< word_syms_rxfilename;
|
||||
|
||||
int32 num_done = 0, num_err = 0;
|
||||
double tot_like = 0.0;
|
||||
int64 num_frames = 0;
|
||||
|
||||
SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
|
||||
RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
|
||||
CompactLatticeWriter clat_writer(clat_wspecifier);
|
||||
|
||||
OnlineTimingStats timing_stats;
|
||||
|
||||
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
|
||||
std::string spk = spk2utt_reader.Key();
|
||||
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
|
||||
OnlineIvectorExtractorAdaptationState adaptation_state(
|
||||
feature_info.ivector_extractor_info);
|
||||
for (size_t i = 0; i < uttlist.size(); i++) {
|
||||
std::string utt = uttlist[i];
|
||||
if (!wav_reader.HasKey(utt)) {
|
||||
KALDI_WARN << "Did not find audio for utterance " << utt;
|
||||
num_err++;
|
||||
continue;
|
||||
}
|
||||
const WaveData &wave_data = wav_reader.Value(utt);
|
||||
// get the data for channel zero (if the signal is not mono, we only
|
||||
// take the first channel).
|
||||
SubVector<BaseFloat> data(wave_data.Data(), 0);
|
||||
|
||||
|
||||
SingleUtteranceNnet2DecoderThreaded decoder(
|
||||
nnet2_decoding_config, trans_model, am_nnet,
|
||||
*decode_fst, feature_info, adaptation_state);
|
||||
|
||||
OnlineTimer decoding_timer(utt);
|
||||
|
||||
BaseFloat samp_freq = wave_data.SampFreq();
|
||||
int32 chunk_length;
|
||||
KALDI_ASSERT(chunk_length_secs > 0);
|
||||
chunk_length = int32(samp_freq * chunk_length_secs);
|
||||
if (chunk_length == 0) chunk_length = 1;
|
||||
|
||||
int32 samp_offset = 0;
|
||||
while (samp_offset < data.Dim()) {
|
||||
int32 samp_remaining = data.Dim() - samp_offset;
|
||||
int32 num_samp = chunk_length < samp_remaining ? chunk_length
|
||||
: samp_remaining;
|
||||
|
||||
SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
|
||||
decoder.AcceptWaveform(samp_freq, wave_part);
|
||||
|
||||
samp_offset += num_samp;
|
||||
// Note: the next call may actually call sleep().
|
||||
decoding_timer.SleepUntil(samp_offset / samp_freq);
|
||||
if (samp_offset == data.Dim()) {
|
||||
// no more input. flush out last frames
|
||||
decoder.InputFinished();
|
||||
}
|
||||
|
||||
if (do_endpointing && decoder.EndpointDetected(endpoint_config)) {
|
||||
decoder.TerminateDecoding();
|
||||
break;
|
||||
}
|
||||
}
|
||||
Timer timer;
|
||||
decoder.Wait();
|
||||
KALDI_VLOG(1) << "Waited " << timer.Elapsed() << " seconds for decoder to "
|
||||
<< "finish after giving it last chunk.";
|
||||
decoder.FinalizeDecoding();
|
||||
|
||||
CompactLattice clat;
|
||||
bool end_of_utterance = true;
|
||||
decoder.GetLattice(end_of_utterance, &clat, NULL);
|
||||
|
||||
GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
|
||||
&num_frames, &tot_like);
|
||||
|
||||
decoding_timer.OutputStats(&timing_stats);
|
||||
|
||||
// In an application you might avoid updating the adaptation state if
|
||||
// you felt the utterance had low confidence. See lat/confidence.h
|
||||
decoder.GetAdaptationState(&adaptation_state);
|
||||
|
||||
// we want to output the lattice with un-scaled acoustics.
|
||||
BaseFloat inv_acoustic_scale =
|
||||
1.0 / nnet2_decoding_config.acoustic_scale;
|
||||
ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
|
||||
|
||||
KALDI_VLOG(1) << "Adding the various end-of-utterance tasks took the "
|
||||
<< "total latency to " << timer.Elapsed() << " seconds.";
|
||||
|
||||
clat_writer.Write(utt, clat);
|
||||
KALDI_LOG << "Decoded utterance " << utt;
|
||||
|
||||
|
||||
|
||||
num_done++;
|
||||
}
|
||||
}
|
||||
bool online = true;
|
||||
timing_stats.Print(online);
|
||||
|
||||
KALDI_LOG << "Decoded " << num_done << " utterances, "
|
||||
<< num_err << " with errors.";
|
||||
KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
|
||||
<< " per frame over " << num_frames << " frames.";
|
||||
delete decode_fst;
|
||||
delete word_syms; // will delete if non-NULL.
|
||||
return (num_done != 0 ? 0 : 1);
|
||||
} catch(const std::exception& e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
} // main()
|
|
@ -46,7 +46,8 @@ Mutex::~Mutex() {
|
|||
<< "a known issue that affects Haswell processors, see "
|
||||
<< "https://sourceware.org/bugzilla/show_bug.cgi?id=16657 "
|
||||
<< "If your processor is not Haswell and you see this message, "
|
||||
<< "it could be a bug in Kaldi.";
|
||||
<< "it could be a bug in Kaldi. However it could be that "
|
||||
<< "multi-threaded code terminated messily.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2010,6 +2010,7 @@ template<class Holder> class RandomAccessTableReaderUnsortedArchiveImpl:
|
|||
|
||||
if (!pr.second) { // Was not inserted-- previous element w/ same key
|
||||
delete holder_; // map was not changed, no ownership transferred.
|
||||
holder_ = NULL;
|
||||
KALDI_ERR << "Error in RandomAccessTableReader: duplicate key "
|
||||
<< cur_key_ << " in archive " << archive_rxfilename_;
|
||||
}
|
||||
|
|
Загрузка…
Ссылка в новой задаче