trunk: add multi-threaded online-nnet2 decoding program, online2-wav-nnet2-latgen-threaded, which does decoding and nnet evaluation in different threads. Usage is otherwise similar to online2-wav-nnet2-latgen-faster.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4844 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Dan Povey 2015-01-31 01:01:00 +00:00
Родитель 41a8f9b3cb
Коммит 59736ea848
34 изменённых файлов: 1839 добавлений и 88 удалений

Просмотреть файл

@ -148,3 +148,26 @@ if [ $stage -le 13 ]; then
fi
exit 0;
###### Comment out the "exit 0" above to run the multi-threaded decoding. #####
if [ $stage -le 14 ]; then
# Demonstrate the multi-threaded decoding.
# put back the pp
test=dev_clean
steps/online/nnet2/decode.sh --threaded true \
--config conf/decode.config --cmd "$decode_cmd" --nj 30 \
--per-utt true exp/tri6b/graph_pp_tgsmall data/$test \
${dir}_online/decode_pp_${test}_tgsmall_utt_threaded || exit 1;
fi
if [ $stage -le 15 ]; then
# Demonstrate the multi-threaded decoding with endpointing.
# put back the pp
test=dev_clean
steps/online/nnet2/decode.sh --threaded true --do-endpointing true \
--config conf/decode.config --cmd "$decode_cmd" --nj 30 \
--per-utt true exp/tri6b/graph_pp_tgsmall data/$test \
${dir}_online/decode_pp_${test}_tgsmall_utt_threaded_ep || exit 1;
fi
exit 0;

Просмотреть файл

@ -151,10 +151,11 @@ utils/mkgraph.sh data/lang_test_4g exp/tri3b exp/tri3b/graph_4g || exit 1;
steps/decode_fmllr.sh --cmd "$decode_cmd" --nj 7 \
exp/tri3b/graph_4g data/test1k exp/tri3b/decode_4g_test1k || exit 1;
# Train RNN for reranking
local/sprak_train_rnnlms.sh data/local/dict data/dev/transcripts.uniq data/local/rnnlms/g_c380_d1k_h100_v130k
# Consumes a lot of memory! Do not run in parallel
local/sprak_run_rnnlms_tri3b.sh data/lang_test_3g data/local/rnnlms/g_c380_d1k_h100_v130k data/test1k exp/tri3b/decode_3g_test1k
# This is commented out for now as it's not important for the main recipe.
## Train RNN for reranking
#local/sprak_train_rnnlms.sh data/local/dict data/dev/transcripts.uniq data/local/rnnlms/g_c380_d1k_h100_v130k
## Consumes a lot of memory! Do not run in parallel
#local/sprak_run_rnnlms_tri3b.sh data/lang_test_3g data/local/rnnlms/g_c380_d1k_h100_v130k data/test1k exp/tri3b/decode_3g_test1k
# From 3b system

Просмотреть файл

@ -8,12 +8,14 @@ stage=0
nj=4
cmd=run.pl
max_active=7000
threaded=false
modify_ivector_config=false # only relevant to threaded decoder.
beam=15.0
lattice_beam=6.0
acwt=0.1 # note: only really affects adaptation and pruning (scoring is on
# lattices).
per_utt=false
online=true
online=true # only relevant to non-threaded decoder.
do_endpointing=false
do_speex_compressing=false
scoring_opts=
@ -92,9 +94,23 @@ if $do_endpointing; then
wav_rspecifier="$wav_rspecifier extend-wav-with-silence ark:- ark:- |"
fi
if $threaded; then
decoder=online2-wav-nnet2-latgen-threaded
# note: the decoder actually uses 4 threads, but the average usage will normally
# be more like 2.
parallel_opts="--num-threads 2"
opts="--modify-ivector-config=$modify_ivector_config --verbose=1"
else
decoder=online2-wav-nnet2-latgen-faster
parallel_opts=
opts="--online=$online"
fi
if [ $stage -le 0 ]; then
$cmd JOB=1:$nj $dir/log/decode.JOB.log \
online2-wav-nnet2-latgen-faster --online=$online --do-endpointing=$do_endpointing \
$cmd $parallel_opts JOB=1:$nj $dir/log/decode.JOB.log \
$decoder $opts --do-endpointing=$do_endpointing \
--config=$srcdir/conf/online_nnet2_decoding.conf \
--max-active=$max_active --beam=$beam --lattice-beam=$lattice_beam \
--acoustic-scale=$acwt --word-symbol-table=$graphdir/words.txt \

Просмотреть файл

@ -3,7 +3,7 @@ all:
include ../kaldi.mk
TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test
TESTFILES = kaldi-math-test io-funcs-test kaldi-error-test timer-test
OBJFILES = kaldi-math.o kaldi-error.o io-funcs.o kaldi-utils.o

Просмотреть файл

@ -128,14 +128,19 @@ std::string KaldiGetStackTrace() {
void KaldiAssertFailure_(const char *func, const char *file,
int32 line, const char *cond_str) {
std::cerr << "KALDI_ASSERT: at " << GetProgramName() << func << ':'
std::ostringstream ss;
ss << "KALDI_ASSERT: at " << GetProgramName() << func << ':'
<< GetShortFileName(file)
<< ':' << line << ", failed: " << cond_str << '\n';
#ifdef HAVE_EXECINFO_H
std::cerr << "Stack trace is:\n" << KaldiGetStackTrace();
ss << "Stack trace is:\n" << KaldiGetStackTrace();
#endif
std::cerr << ss.str();
std::cerr.flush();
abort(); // Will later throw instead if needed.
// We used to call abort() here, but switch to throwing an exception
// (like KALDI_ERR) because it's easier to deal with in multi-threaded
// code.
throw std::runtime_error(ss.str());
}

Просмотреть файл

@ -19,6 +19,15 @@
#include <string>
#include "base/kaldi-common.h"
#ifdef _WIN32_WINNT_WIN8
#include <Synchapi.h>
#elif defined (_WIN32) || defined(_MSC_VER) || defined(MINGW)
#include <Windows.h>
#else
#include <unistd.h>
#endif
namespace kaldi {
std::string CharToString(const char &c) {
@ -30,4 +39,12 @@ std::string CharToString(const char &c) {
return (std::string) buf;
}
void Sleep(float seconds) {
#if defined(_MSC_VER) || defined(MINGW)
::Sleep(static_cast<int>(seconds * 1000.0));
#else
usleep(static_cast<int>(seconds * 1000000.0));
#endif
}
} // end namespace kaldi

Просмотреть файл

@ -78,6 +78,10 @@ inline int MachineIsLittleEndian() {
return (*reinterpret_cast<char*>(&check) != 0);
}
// This function kaldi::Sleep() provides a portable way to sleep for a possibly fractional
// number of seconds. On Windows it's only accurate to microseconds.
void Sleep(float seconds);
}
#define KALDI_SWAP8(a) { \

Просмотреть файл

@ -1,6 +1,7 @@
// base/timer-test.cc
// Copyright 2009-2011 Microsoft Corporation
// 2014 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
@ -19,28 +20,27 @@
#include "base/timer.h"
#include "base/kaldi-common.h"
#include "base/kaldi-utils.h"
namespace kaldi {
void TimerTest() {
float time_secs = 0.025 * (rand() % 10);
std::cout << "target is " << time_secs << "\n";
Timer timer;
#if defined(_MSC_VER) || defined(MINGW)
Sleep(1000);
#else
sleep(1);
#endif
Sleep(time_secs);
BaseFloat f = timer.Elapsed();
std::cout << "time is " << f;
KALDI_ASSERT(fabs(1.0 - f) < 0.1);
std::cout << "time is " << f << std::endl;
if (fabs(time_secs - f) > 0.05)
KALDI_ERR << "Timer fail: waited " << f << " seconds instead of "
<< time_secs << " secs.";
}
}
int main() {
for (int i = 0; i < 4; i++)
kaldi::TimerTest();
}

Просмотреть файл

@ -19,9 +19,11 @@
#ifndef KALDI_BASE_TIMER_H_
#define KALDI_BASE_TIMER_H_
#if defined(_MSC_VER) || defined(MINGW)
#include "base/kaldi-utils.h"
// Note: Sleep(float secs) is included in base/kaldi-utils.h.
#if defined(_MSC_VER) || defined(MINGW)
namespace kaldi
{

Просмотреть файл

@ -83,6 +83,88 @@ class DecodableMatrixScaledMapped: public DecodableInterface {
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixScaledMapped);
};
/**
This decodable class returns log-likes stored in a matrix; it supports
repeatedly writing to the matrix and setting a time-offset representing the
frame-index of the first row of the matrix. It's intended for use in
multi-threaded decoding; mutex and semaphores are not included. External
code will call SetLoglikes() each time more log-likelihods are available.
If you try to access a log-likelihood that's no longer available because
the frame index is less than the current offset, it is of course an error.
*/
class DecodableMatrixMappedOffset: public DecodableInterface {
public:
DecodableMatrixMappedOffset(const TransitionModel &tm):
trans_model_(tm), frame_offset_(0), input_is_finished_(false) { }
virtual int32 NumFramesReady() { return frame_offset_ + loglikes_.NumRows(); }
// this is not part of the generic Decodable interface.
int32 FirstAvailableFrame() { return frame_offset_; }
// This function is destructive of the input "loglikes" because it may
// under some circumstances do a shallow copy using Swap(). This function
// appends loglikes to any existing likelihoods you've previously supplied.
// frames_to_discard, if nonzero, will discard that number of previously
// available frames, from the left, advancing FirstAvailableFrame() by
// a number equal to frames_to_discard. You should only set frames_to_discard
// to nonzero if you know your decoder won't want to access the loglikes
// for older frames.
void AcceptLoglikes(Matrix<BaseFloat> *loglikes,
int32 frames_to_discard) {
if (loglikes->NumRows() == 0) return;
KALDI_ASSERT(loglikes->NumCols() == trans_model_.NumPdfs());
KALDI_ASSERT(frames_to_discard <= loglikes_.NumRows() &&
frames_to_discard >= 0);
if (frames_to_discard == loglikes_.NumRows()) {
loglikes_.Swap(loglikes);
loglikes->Resize(0, 0);
} else {
int32 old_rows_kept = loglikes_.NumRows() - frames_to_discard,
new_num_rows = old_rows_kept + loglikes->NumRows();
Matrix<BaseFloat> new_loglikes(new_num_rows, loglikes->NumCols());
new_loglikes.RowRange(0, old_rows_kept).CopyFromMat(
loglikes_.RowRange(frames_to_discard, old_rows_kept));
new_loglikes.RowRange(old_rows_kept, loglikes->NumRows()).CopyFromMat(
*loglikes);
loglikes_.Swap(&new_loglikes);
}
frame_offset_ += frames_to_discard;
}
void InputIsFinished() { input_is_finished_ = true; }
virtual int32 NumFramesReady() const {
return loglikes_.NumRows() + frame_offset_;
}
virtual bool IsLastFrame(int32 frame) const {
KALDI_ASSERT(frame < NumFramesReady());
return (frame == NumFramesReady() - 1 && input_is_finished_);
}
virtual BaseFloat LogLikelihood(int32 frame, int32 tid) {
int32 index = frame - frame_offset_;
KALDI_ASSERT(index >= 0 && index < loglikes_.NumRows());
return loglikes_(index, trans_model_.TransitionIdToPdf(tid));
}
virtual int32 NumIndices() const { return trans_model_.NumTransitionIds(); }
// nothing special to do in destructor.
virtual ~DecodableMatrixMappedOffset() { }
private:
const TransitionModel &trans_model_; // for tid to pdf mapping
Matrix<BaseFloat> loglikes_;
int32 frame_offset_;
bool input_is_finished_;
KALDI_DISALLOW_COPY_AND_ASSIGN(DecodableMatrixMappedOffset);
};
class DecodableMatrixScaled: public DecodableInterface {
public:

Просмотреть файл

@ -19,11 +19,9 @@
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// Note on svn: this file is "upstream" from lattice-faster-online-decoder.h,
// and changes in this file should be merged into
// lattice-faster-online-decoder.h, after committing the changes to this file,
// using the command
// svn merge ^/sandbox/online/src/decoder/lattice-faster-decoder.h lattice-faster-online-decoder.h
// Note: this file is "upstream" from lattice-faster-online-decoder.h,
// and changes in this file should be made to lattice-faster-online-decoder.h,
// if applicable.
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
@ -165,8 +163,8 @@ class LatticeFasterDecoder {
bool use_final_probs = true) const;
/// InitDecoding initializes the decoding, and should only be used if you
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need
/// to call this. You can call InitDecoding if you have already decoded an
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
/// call this. You can also call InitDecoding if you have already decoded an
/// utterance and want to start with a new utterance.
void InitDecoding();
@ -408,6 +406,7 @@ class LatticeFasterDecoder {
void ClearActiveTokens();
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoder);
};

Просмотреть файл

@ -148,8 +148,8 @@ class LatticeFasterOnlineDecoder {
/// InitDecoding initializes the decoding, and should only be used if you
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need
/// to call this. You can call InitDecoding if you have already decoded an
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
/// call this. You can also call InitDecoding if you have already decoded an
/// utterance and want to start with a new utterance.
void InitDecoding();
@ -398,6 +398,8 @@ class LatticeFasterOnlineDecoder {
void ClearActiveTokens();
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterOnlineDecoder);
};

Просмотреть файл

@ -604,7 +604,7 @@ void OnlineIvectorEstimationStats::GetIvector(
ivector->SetZero();
(*ivector)(0) = prior_offset_;
}
KALDI_VLOG(3) << "Objective function improvement from estimating the "
KALDI_VLOG(4) << "Objective function improvement from estimating the "
<< "iVector (vs. default value) is "
<< ObjfChange(*ivector);
}

Просмотреть файл

@ -20,7 +20,7 @@ OBJFILES = nnet-component.o nnet-nnet.o train-nnet.o train-nnet-ensemble.o nnet-
get-feature-transform.o widen-nnet.o nnet-precondition-online.o \
nnet-example-functions.o nnet-compute-discriminative.o \
nnet-compute-discriminative-parallel.o online-nnet2-decodable.o \
train-nnet-perturbed.o
train-nnet-perturbed.o nnet-compute-online.o
LIBNAME = kaldi-nnet2

Просмотреть файл

@ -20,16 +20,6 @@
#include "nnet2/nnet-component.h"
#include "util/common-utils.h"
#ifdef _WIN32_WINNT_WIN8
#include <Synchapi.h>
#define sleep Sleep
#elif _WIN32
#include <Windows.h>
#define sleep Sleep
#else
#include <unistd.h> // for sleep().
#endif
namespace kaldi {
namespace nnet2 {
@ -332,7 +322,7 @@ void UnitTestAffineComponent() {
mat.SetRandn();
mat.Scale(param_stddev);
WriteKaldiObject(mat, "tmpf", true);
sleep(1);
Sleep(0.5);
component.Init(learning_rate, "tmpf");
unlink("tmpf");
}
@ -433,7 +423,7 @@ void UnitTestAffineComponentPreconditioned() {
mat.SetRandn();
mat.Scale(param_stddev);
WriteKaldiObject(mat, "tmpf", true);
sleep(1);
Sleep(0.5);
component.Init(learning_rate, alpha, max_change, "tmpf");
unlink("tmpf");
}
@ -467,7 +457,7 @@ void UnitTestAffineComponentPreconditionedOnline() {
mat.SetRandn();
mat.Scale(param_stddev);
WriteKaldiObject(mat, "tmpf", true);
sleep(1);
Sleep(0.5);
component.Init(learning_rate, rank_in, rank_out,
update_period, num_samples_history, alpha,
max_change_per_sample, "tmpf");

Просмотреть файл

@ -0,0 +1,131 @@
// nnet2/nnet-compute-online.cc
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
// Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet2/nnet-compute-online.h"
namespace kaldi {
namespace nnet2 {
NnetOnlineComputer::NnetOnlineComputer(const Nnet &nnet, bool pad_input)
: nnet_(nnet), pad_input_(pad_input),
is_first_chunk_(true), finished_(false) {
data_.resize(nnet_.NumComponents() + 1);
unused_input_.Resize(0, 0);
// WARNING: we always pad in this dirty implementation.
pad_input_ = true;
}
void NnetOnlineComputer::Compute(const CuMatrixBase<BaseFloat> &input,
CuMatrix<BaseFloat> *output) {
KALDI_ASSERT(output != NULL);
KALDI_ASSERT(!finished_);
int32 dim = input.NumCols();
// If input is empty, we also set output to zero size.
if (input.NumRows() == 0) {
output->Resize(0, 0);
return;
}
// Check if feature dimension matches that required by the neural network.
if (dim != nnet_.InputDim()) {
KALDI_ERR << "Feature dimension is " << dim << ", but network expects "
<< nnet_.InputDim();
}
// Pad at the start of the file if necessary.
if (pad_input_ && is_first_chunk_) {
KALDI_ASSERT(unused_input_.NumRows() == 0);
unused_input_.Resize(nnet_.LeftContext(), dim);
for (int32 i = 0; i < nnet_.LeftContext(); i++)
unused_input_.Row(i).CopyFromVec(input.Row(0));
is_first_chunk_ = false;
}
int32 num_rows = unused_input_.NumRows() + input.NumRows();
// Splice unused_input_ and input.
CuMatrix<BaseFloat> &input_data(data_[0]);
input_data.Resize(num_rows, dim);
input_data.Range(0, unused_input_.NumRows(),
0, dim).CopyFromMat(unused_input_);
input_data.Range(unused_input_.NumRows(), input.NumRows(),
0, dim).CopyFromMat(input);
if (num_rows > nnet_.LeftContext() + nnet_.RightContext()) {
nnet_.ComputeChunkInfo(num_rows, 1, &chunk_info_);
Propagate();
*output = data_.back();
} else {
output->Resize(0, 0);
}
// Now store the part of input that will be needed in the next call of
// Compute().
int32 unused_num_rows = nnet_.LeftContext() + nnet_.RightContext();
if (unused_num_rows > num_rows) { unused_num_rows = num_rows; }
unused_input_.Resize(unused_num_rows, dim);
unused_input_.CopyFromMat(input_data.Range(num_rows - unused_num_rows,
unused_num_rows, 0, dim));
}
void NnetOnlineComputer::Flush(CuMatrix<BaseFloat> *output) {
KALDI_ASSERT(!finished_);
int32 right_context = (pad_input_ ? nnet_.RightContext() : 0);
int32 num_rows = unused_input_.NumRows() + right_context;
// If no frame needs to be computed, set output to empty and return.
if (num_rows == 0 || unused_input_.NumRows() == 0) {
output->Resize(0, 0);
finished_ = true;
return;
}
int32 dim = unused_input_.NumCols();
CuMatrix<BaseFloat> &input_data(data_[0]);
input_data.Resize(num_rows, dim);
input_data.Range(0, unused_input_.NumRows(),
0, dim).CopyFromMat(unused_input_);
if (right_context > 0) {
int32 last_row = unused_input_.NumRows() - 1;
for (int32 i = 0; i < right_context; i++)
input_data.Row(num_rows - i - 1).CopyFromVec(unused_input_.Row(last_row));
}
nnet_.ComputeChunkInfo(num_rows, 1, &chunk_info_);
Propagate();
*output = data_.back();
finished_ = true;
}
void NnetOnlineComputer::Propagate() {
for (int32 c = 0; c < nnet_.NumComponents(); ++c) {
const Component &component = nnet_.GetComponent(c);
CuMatrix<BaseFloat> &input_data = data_[c], &output_data = data_[c + 1];
component.Propagate(chunk_info_[c], chunk_info_[c + 1],
input_data, &output_data);
}
}
} // namespace nnet2
} // namespace kaldi

Просмотреть файл

@ -0,0 +1,99 @@
// nnet2/nnet-compute-online.h
// Copyright 2014 Johns Hopkins University (author: Daniel Povey)
// Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET2_NNET_COMPUTE_ONLINE_H_
#define KALDI_NNET2_NNET_COMPUTE_ONLINE_H_
#include "nnet2/nnet-nnet.h"
namespace kaldi {
namespace nnet2 {
/* This header provides functionality for doing forward computation in a situation
where you want to start from the beginning of a file and progressively compute
more, while re-using the hidden parts that (due to context) may be shared.
(note: this sharing is more of an issue in multi-splice networks where there is
splicing over time in the middle layers of the network).
Note: this doesn't do the final taking-the-log and correcting for the prior.
The current implementation is just an inefficient placeholder implementation;
later we'll modify it to properly use previously computed activations.
*/
class NnetOnlineComputer {
public:
// All the inputs and outputs are of type CuMatrix, in case we're doing the
// computation on the GPU (of course, if there is no GPU, it backs off to
// using the CPU).
// You should initialize an object of this type for each utterance you want
// to decode.
// Note: pad_input will normally be true; it means that at the start and end
// of the file, we pad with repeats of the first/last frame, so that the total
// number of frames it outputs is the same as the number of input frames.
NnetOnlineComputer(const Nnet &nnet,
bool pad_input);
// This function works as follows: given a chunk of input (interpreted
// as following in time any previously supplied data), do the computation
// and produce all the frames of output we can. In the middle of the
// file, the dimensions of input and output will be the same, but at
// the beginning of the file, output will have fewer frames than input
// due to required context.
// It is the responsibility of the user to keep track of frame indices, if
// required. This class won't output any frame twice.
void Compute(const CuMatrixBase<BaseFloat> &input,
CuMatrix<BaseFloat> *output);
// This flushes out the last frames of output; you call this when all
// input has finished. It's invalid to call Compute or Flush after
// calling Flush. It's valid to call Flush if no frames have been
// input or if no frames have been output; this produces empty output.
void Flush(CuMatrix<BaseFloat> *output);
private:
void Propagate();
const Nnet &nnet_;
// data_ contains the intermediate stages and the output of the most recent
// computation.
std::vector<CuMatrix<BaseFloat> > data_;
std::vector<ChunkInfo> chunk_info_;
CuMatrix<BaseFloat> unused_input_;
bool pad_input_;
bool is_first_chunk_;
bool finished_;
// we might need more variables here to keep track of how many frames we
// already output from data_.
KALDI_DISALLOW_COPY_AND_ASSIGN(NnetOnlineComputer);
};
} // namespace nnet2
} // namespace kaldi
#endif // KALDI_NNET2_NNET_COMPUTE_ONLINE_H_

Просмотреть файл

@ -30,6 +30,10 @@
namespace kaldi {
namespace nnet2 {
// Note: see also nnet-compute-online.h, which provides a different
// (lower-level) interface and more efficient for progressive evaluation of an
// nnet throughout an utterance, with re-use of already-computed activations.
struct DecodableNnet2OnlineOptions {
BaseFloat acoustic_scale;
bool pad_input;
@ -55,6 +59,12 @@ struct DecodableNnet2OnlineOptions {
};
/**
This Decodable object for class nnet2::AmNnet takes feature input from class
OnlineFeatureInterface, unlike, say, class DecodableAmNnet which takes
feature input from a matrix.
*/
class DecodableNnet2Online: public DecodableInterface {
public:
DecodableNnet2Online(const AmNnet &nnet,

Просмотреть файл

@ -7,7 +7,8 @@ TESTFILES =
OBJFILES = online-gmm-decodable.o online-feature-pipeline.o online-ivector-feature.o \
online-nnet2-feature-pipeline.o online-gmm-decoding.o online-timing.o \
online-endpoint.o onlinebin-util.o online-speex-wrapper.o online-nnet2-decoding.o
online-endpoint.o onlinebin-util.o online-speex-wrapper.o \
online-nnet2-decoding.o online-nnet2-decoding-threaded.o
LIBNAME = kaldi-online2

Просмотреть файл

@ -0,0 +1,652 @@
// online2/online-nnet2-decoding-threaded.cc
// Copyright 2013-2014 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "online2/online-nnet2-decoding-threaded.h"
#include "nnet2/nnet-compute-online.h"
#include "lat/lattice-functions.h"
#include "lat/determinize-lattice-pruned.h"
#include "thread/kaldi-thread.h"
namespace kaldi {
ThreadSynchronizer::ThreadSynchronizer():
abort_(false),
producer_waiting_(false),
consumer_waiting_(false),
num_errors_(0) {
producer_semaphore_.Signal();
consumer_semaphore_.Signal();
}
bool ThreadSynchronizer::Lock(ThreadType t) {
if (abort_)
return false;
if (t == ThreadSynchronizer::kProducer) {
producer_semaphore_.Wait();
} else {
consumer_semaphore_.Wait();
}
if (abort_)
return false;
mutex_.Lock();
held_by_ = t;
if (abort_) {
mutex_.Unlock();
return false;
} else {
return true;
}
}
bool ThreadSynchronizer::UnlockSuccess(ThreadType t) {
if (t == ThreadSynchronizer::kProducer) {
producer_semaphore_.Signal(); // next Lock won't wait.
if (consumer_waiting_) {
consumer_semaphore_.Signal();
consumer_waiting_ = false;
}
} else {
consumer_semaphore_.Signal(); // next Lock won't wait.
if (producer_waiting_) {
producer_semaphore_.Signal();
producer_waiting_ = false;
}
}
mutex_.Unlock();
return !abort_;
}
bool ThreadSynchronizer::UnlockFailure(ThreadType t) {
KALDI_ASSERT(held_by_ == t && "Code error: unlocking a mutex you don't hold.");
if (t == ThreadSynchronizer::kProducer) {
KALDI_ASSERT(!producer_waiting_ && "code error.");
producer_waiting_ = true;
} else {
KALDI_ASSERT(!consumer_waiting_ && "code error.");
consumer_waiting_ = true;
}
mutex_.Unlock();
return !abort_;
}
void ThreadSynchronizer::SetAbort() {
abort_ = true;
// we signal the semaphores just in case someone was waiting on either of
// them.
producer_semaphore_.Signal();
consumer_semaphore_.Signal();
}
ThreadSynchronizer::~ThreadSynchronizer() {
}
// static
void OnlineNnet2DecodingThreadedConfig::Check() {
KALDI_ASSERT(max_buffered_features > 1);
KALDI_ASSERT(feature_batch_size > 0);
KALDI_ASSERT(max_loglikes_copy >= 0);
KALDI_ASSERT(nnet_batch_size > 0);
KALDI_ASSERT(decode_batch_size >= 1);
}
SingleUtteranceNnet2DecoderThreaded::SingleUtteranceNnet2DecoderThreaded(
const OnlineNnet2DecodingThreadedConfig &config,
const TransitionModel &tmodel,
const nnet2::AmNnet &am_nnet,
const fst::Fst<fst::StdArc> &fst,
const OnlineNnet2FeaturePipelineInfo &feature_info,
const OnlineIvectorExtractorAdaptationState &adaptation_state):
config_(config), am_nnet_(am_nnet), tmodel_(tmodel), sampling_rate_(0.0),
num_samples_received_(0), input_finished_(false),
feature_pipeline_(feature_info), feature_buffer_start_frame_(0),
feature_buffer_finished_(false), decodable_(tmodel),
num_frames_decoded_(0), decoder_(fst, config_.decoder_opts),
abort_(false), error_(false) {
// if the user supplies an adaptation state that was not freshly initialized,
// it means that we take the adaptation state from the previous
// utterance(s)... this only makes sense if theose previous utterance(s) are
// believed to be from the same speaker.
feature_pipeline_.SetAdaptationState(adaptation_state);
// spawn threads.
pthread_attr_t pthread_attr;
pthread_attr_init(&pthread_attr);
int32 ret;
// Note: if the constructor throws an exception, the corresponding destructor
// will not be called. So we don't have to be careful about setting the
// thread pointers to NULL after we've joined them.
if ((ret=pthread_create(&(threads_[0]),
&pthread_attr, RunFeatureExtraction,
(void*)this)) != 0) {
const char *c = strerror(ret);
if (c == NULL) { c = "[NULL]"; }
KALDI_ERR << "Error creating thread, errno was: " << c;
}
if ((ret=pthread_create(&(threads_[1]),
&pthread_attr, RunNnetEvaluation,
(void*)this)) != 0) {
const char *c = strerror(ret);
if (c == NULL) { c = "[NULL]"; }
bool error = true;
AbortAllThreads(error);
KALDI_WARN << "Error creating thread, errno was: " << c
<< " (will rejoin already-created threads).";
if (pthread_join(threads_[0], NULL)) {
KALDI_ERR << "Error rejoining thread.";
} else {
KALDI_ERR << "Error creating thread, errno was: " << c;
}
}
if ((ret=pthread_create(&(threads_[2]),
&pthread_attr, RunDecoderSearch,
(void*)this)) != 0) {
const char *c = strerror(ret);
if (c == NULL) { c = "[NULL]"; }
bool error = true;
AbortAllThreads(error);
KALDI_WARN << "Error creating thread, errno was: " << c
<< " (will rejoin already-created threads).";
if (pthread_join(threads_[0], NULL) || pthread_join(threads_[1], NULL)) {
KALDI_ERR << "Error rejoining thread.";
} else {
KALDI_ERR << "Error creating thread, errno was: " << c;
}
}
}
SingleUtteranceNnet2DecoderThreaded::~SingleUtteranceNnet2DecoderThreaded() {
if (!abort_) {
// If we have not already started the process of aborting the threads, do so now.
bool error = false;
AbortAllThreads(error);
}
// join all the threads (this avoids leaving zombie threads around, or threads
// that might be accessing deconstructed object).
WaitForAllThreads();
DeletePointers(&input_waveform_);
DeletePointers(&feature_buffer_);
}
void SingleUtteranceNnet2DecoderThreaded::AcceptWaveform(
BaseFloat sampling_rate,
const VectorBase<BaseFloat> &wave_part) {
if (sampling_rate_ <= 0.0)
sampling_rate_ = sampling_rate;
else {
KALDI_ASSERT(sampling_rate == sampling_rate_);
}
num_samples_received_ += wave_part.Dim();
if (wave_part.Dim() == 0) return;
if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kProducer)) {
KALDI_ERR << "Failure locking mutex: decoding aborted.";
}
Vector<BaseFloat> *new_part = new Vector<BaseFloat>(wave_part);
input_waveform_.push_back(new_part);
// we always unlock with success because there is no buffer size limitation
// for the waveform so no reason why we might wait.
waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer);
}
int32 SingleUtteranceNnet2DecoderThreaded::NumFramesReceivedApprox() const {
return num_samples_received_ /
(sampling_rate_ * feature_pipeline_.FrameShiftInSeconds());
}
void SingleUtteranceNnet2DecoderThreaded::InputFinished() {
// setting input_finished_ = true informs the feature-processing pipeline
// to expect no more input, and to flush out the last few frames if there
// is any latency in the pipeline (e.g. due to pitch).
if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kProducer)) {
KALDI_ERR << "Failure locking mutex: decoding aborted.";
}
KALDI_ASSERT(!input_finished_ && "InputFinished called twice");
input_finished_ = true;
waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer);
}
void SingleUtteranceNnet2DecoderThreaded::TerminateDecoding() {
bool error = false;
AbortAllThreads(error);
}
void SingleUtteranceNnet2DecoderThreaded::Wait() {
if (!input_finished_ && !abort_) {
KALDI_ERR << "You cannot call Wait() before calling either InputFinished() "
<< "or TerminateDecoding().";
}
WaitForAllThreads();
}
void SingleUtteranceNnet2DecoderThreaded::FinalizeDecoding() {
if (KALDI_PTHREAD_PTR(threads_[0]) != 0) {
KALDI_ERR << "It is an error to call FinalizeDecoding before Wait().";
}
decoder_.FinalizeDecoding();
}
void SingleUtteranceNnet2DecoderThreaded::GetAdaptationState(
OnlineIvectorExtractorAdaptationState *adaptation_state) {
feature_pipeline_mutex_.Lock(); // If this blocks, it shouldn't be for very long.
feature_pipeline_.GetAdaptationState(adaptation_state);
feature_pipeline_mutex_.Unlock(); // If this blocks, it won't be for very long.
}
void SingleUtteranceNnet2DecoderThreaded::GetLattice(
bool end_of_utterance,
CompactLattice *clat,
BaseFloat *final_relative_cost) const {
clat->DeleteStates();
// we'll make an exception to the normal const rules, for mutexes, since
// we're not really changing the class.
const_cast<Mutex&>(decoder_mutex_).Lock();
if (final_relative_cost != NULL)
*final_relative_cost = decoder_.FinalRelativeCost();
if (decoder_.NumFramesDecoded() == 0) {
const_cast<Mutex&>(decoder_mutex_).Unlock();
clat->SetFinal(clat->AddState(),
CompactLatticeWeight::One());
return;
}
Lattice raw_lat;
decoder_.GetRawLattice(&raw_lat, end_of_utterance);
const_cast<Mutex&>(decoder_mutex_).Unlock();
if (!config_.decoder_opts.determinize_lattice)
KALDI_ERR << "--determinize-lattice=false option is not supported at the moment";
BaseFloat lat_beam = config_.decoder_opts.lattice_beam;
DeterminizeLatticePhonePrunedWrapper(
tmodel_, &raw_lat, lat_beam, clat, config_.decoder_opts.det_opts);
}
void SingleUtteranceNnet2DecoderThreaded::GetBestPath(
bool end_of_utterance,
Lattice *best_path,
BaseFloat *final_relative_cost) const {
// we'll make an exception to the normal const rules, for mutexes, since
// we're not really changing the class.
const_cast<Mutex&>(decoder_mutex_).Lock();
if (decoder_.NumFramesDecoded() == 0) {
// It's possible that this if-statement is not necessary because we'd get this
// anyway if we just called GetBestPath on the decoder.
best_path->DeleteStates();
best_path->SetFinal(best_path->AddState(),
LatticeWeight::One());
if (final_relative_cost != NULL)
*final_relative_cost = std::numeric_limits<BaseFloat>::infinity();
} else {
decoder_.GetBestPath(best_path,
end_of_utterance);
if (final_relative_cost != NULL)
*final_relative_cost = decoder_.FinalRelativeCost();
}
const_cast<Mutex&>(decoder_mutex_).Unlock();
}
void SingleUtteranceNnet2DecoderThreaded::AbortAllThreads(bool error) {
abort_ = true;
if (error)
error_ = true;
waveform_synchronizer_.SetAbort();
feature_synchronizer_.SetAbort();
decodable_synchronizer_.SetAbort();
}
int32 SingleUtteranceNnet2DecoderThreaded::NumFramesDecoded() const {
const_cast<Mutex&>(decoder_mutex_).Lock();
int32 ans = decoder_.NumFramesDecoded();
const_cast<Mutex&>(decoder_mutex_).Unlock();
return ans;
}
void* SingleUtteranceNnet2DecoderThreaded::RunFeatureExtraction(void *ptr_in) {
SingleUtteranceNnet2DecoderThreaded *me =
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
try {
if (!me->RunFeatureExtractionInternal() && !me->abort_)
KALDI_ERR << "Returned abnormally and abort was not called";
} catch(const std::exception &e) {
KALDI_WARN << "Caught exception: " << e.what();
// if an error happened in one thread, we need to make sure the other
// threads can exit too.
bool error = true;
me->AbortAllThreads(error);
}
return NULL;
}
void* SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluation(void *ptr_in) {
SingleUtteranceNnet2DecoderThreaded *me =
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
try {
if (!me->RunNnetEvaluationInternal() && !me->abort_)
KALDI_ERR << "Returned abnormally and abort was not called";
} catch(const std::exception &e) {
KALDI_WARN << "Caught exception: " << e.what();
// if an error happened in one thread, we need to make sure the other
// threads can exit too.
bool error = true;
me->AbortAllThreads(error);
}
return NULL;
}
void* SingleUtteranceNnet2DecoderThreaded::RunDecoderSearch(void *ptr_in) {
SingleUtteranceNnet2DecoderThreaded *me =
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
try {
if (!me->RunDecoderSearchInternal() && !me->abort_)
KALDI_ERR << "Returned abnormally and abort was not called";
} catch(const std::exception &e) {
KALDI_WARN << "Caught exception: " << e.what();
// if an error happened in one thread, we need to make sure the other threads can exit too.
bool error = true;
me->AbortAllThreads(error);
}
return NULL;
}
void SingleUtteranceNnet2DecoderThreaded::WaitForAllThreads() {
for (int32 i = 0; i < 3; i++) { // there are 3 spawned threads.
pthread_t &thread = threads_[i];
if (KALDI_PTHREAD_PTR(thread) != 0) {
if (pthread_join(thread, NULL)) {
KALDI_ERR << "Error rejoining thread"; // this should not happen.
}
KALDI_PTHREAD_PTR(thread) = 0;
}
}
if (error_) {
KALDI_ERR << "Error encountered during decoding. See above.";
}
}
bool SingleUtteranceNnet2DecoderThreaded::RunFeatureExtractionInternal() {
// Note: if any of the functions Lock, UnlockSuccess, UnlockFailure return
// false, it is because AbortAllThreads() called, and we return false
// immediately.
// num_frames_output is a local variable that keeps track of how many
// frames we have output to the feature buffer, for this utterance.
int32 num_frames_output = 0;
while (true) {
// First deal with accepting input.
if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
return false;
if (input_waveform_.empty()) {
if (input_finished_ &&
!feature_pipeline_.IsLastFrame(feature_pipeline_.NumFramesReady()-1)) {
// the main thread called InputFinished() and set input_finished_, and
// we haven't yet registered that fact. This is progress so
// UnlockSuccess().
feature_pipeline_.InputFinished();
if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
return false;
} else {
// there was no input to process. However, we only call UnlockFailure() if we
// are blocked on the fact that there is no input to process; otherwise we
// call UnlockSuccess().
if (num_frames_output == feature_pipeline_.NumFramesReady()) {
// we need to wait until there is more input.
if (!waveform_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
return false;
} else { // we can keep looping.
if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
return false;
}
}
} else { // there is more wav data.
{ // braces clarify scope of locking.
feature_pipeline_mutex_.Lock();
for (size_t i = 0; i < input_waveform_.size(); i++)
if (input_waveform_[i]->Dim() != 0)
feature_pipeline_.AcceptWaveform(sampling_rate_, *input_waveform_[i]);
feature_pipeline_mutex_.Unlock();
}
DeletePointers(&input_waveform_);
input_waveform_.clear();
if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
return false;
}
if (!feature_synchronizer_.Lock(ThreadSynchronizer::kProducer)) return false;
if (feature_buffer_.size() >= config_.max_buffered_features) {
// we need to block on the output buffer.
if (!feature_synchronizer_.UnlockFailure(ThreadSynchronizer::kProducer))
return false;
} else {
{ // braces clarify scope of locking.
feature_pipeline_mutex_.Lock();
// There is buffer space available; deal with producing output.
int32 cur_size = feature_buffer_.size(),
batch_size = config_.feature_batch_size,
feat_dim = feature_pipeline_.Dim();
for (int32 t = feature_buffer_start_frame_ +
static_cast<int32>(feature_buffer_.size());
t < feature_buffer_start_frame_ + config_.max_buffered_features &&
t < feature_buffer_start_frame_ + cur_size + batch_size &&
t < feature_pipeline_.NumFramesReady(); t++) {
Vector<BaseFloat> *feats = new Vector<BaseFloat>(feat_dim, kUndefined);
// Note: most of the actual computation occurs.
feature_pipeline_.GetFrame(t, feats);
feature_buffer_.push_back(feats);
}
num_frames_output = feature_buffer_start_frame_ + feature_buffer_.size();
if (feature_pipeline_.IsLastFrame(num_frames_output - 1)) {
// e.g. user called InputFinished() and we already saw the last frame.
feature_buffer_finished_ = true;
}
feature_pipeline_mutex_.Unlock();
}
if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer))
return false;
if (feature_buffer_finished_) return true;
}
}
}
bool SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal() {
// if any of the Lock/Unlock functions return false, it's because AbortAllThreads()
// was called.
// This object is responsible for keeping track of the context, and avoiding
// re-computing things we've already computed.
bool pad_input = true;
nnet2::NnetOnlineComputer computer(am_nnet_.GetNnet(), pad_input);
// we declare the following as CuVector just to enable GPU support, but
// we expect this code to be run on CPU in the normal case.
CuVector<BaseFloat> log_inv_priors(am_nnet_.Priors());
log_inv_priors.ApplyFloor(1.0e-20); // should have no effect.
log_inv_priors.ApplyLog();
log_inv_priors.Scale(-1.0);
int32 num_frames_output = 0;
while (true) {
bool last_time = false;
if (!feature_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
return false;
CuMatrix<BaseFloat> cu_loglikes;
if (feature_buffer_.empty()) {
if (feature_buffer_finished_) {
// flush out the last few frames. Note: this is the only place from
// which we check feature_buffer_finished_, and we'll exit the loop, so
// if we reach here it must be the first time it was true.
last_time = true;
if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
return false;
computer.Flush(&cu_loglikes);
} else {
// there is nothing to do because there is no input. Next call to Lock
// should block till the feature-processing thread does something.
if (!feature_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
return false;
}
} else {
int32 num_frames_evaluate = std::min<int32>(feature_buffer_.size(),
config_.nnet_batch_size),
feat_dim = feature_buffer_[0]->Dim();
Matrix<BaseFloat> feats(num_frames_evaluate, feat_dim);
for (int32 i = 0; i < num_frames_evaluate; i++) {
feats.Row(i).CopyFromVec(*(feature_buffer_[i]));
delete feature_buffer_[i];
}
feature_buffer_.erase(feature_buffer_.begin(),
feature_buffer_.begin() + num_frames_evaluate);
feature_buffer_start_frame_ += num_frames_evaluate;
if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
return false;
CuMatrix<BaseFloat> cu_feats;
cu_feats.Swap(&feats); // If we don't have a GPU (and not having a GPU is
// the normal expected use-case for this code),
// this would be a lightweight operation, swapping
// pointers.
KALDI_VLOG(4) << "Computing chunk of " << cu_feats.NumRows() << " frames "
<< "of nnet.";
computer.Compute(cu_feats, &cu_loglikes);
cu_loglikes.ApplyFloor(1.0e-20);
cu_loglikes.ApplyLog();
// take the log-posteriors and turn them into pseudo-log-likelihoods by
// dividing by the pdf priors; then scale by the acoustic scale.
if (cu_loglikes.NumRows() != 0) {
cu_loglikes.AddVecToRows(1.0, log_inv_priors);
cu_loglikes.Scale(config_.acoustic_scale);
}
}
Matrix<BaseFloat> loglikes;
loglikes.Swap(&cu_loglikes); // If we don't have a GPU (and not having a
// GPU is the normal expected use-case for
// this code), this would be a lightweight
// operation, swapping pointers.
// OK, at this point we may have some newly created log-likes and we want to
// give them to the decoding thread.
int32 num_loglike_frames = loglikes.NumRows();
if (loglikes.NumRows() != 0) { // if we need to output some loglikes...
while (true) {
// we may have to grab and release the decodable mutex
// a few times before it's ready to accept the loglikes.
if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kProducer))
return false;
int32 num_frames_decoded = num_frames_decoded_;
// we can't have output fewer frames than were decoded.
KALDI_ASSERT(num_frames_output >= num_frames_decoded);
if (num_frames_output - num_frames_decoded < config_.max_loglikes_copy) {
// If we would have to copy fewer than config_.max_loglikes_copy
// previously evaluated log-likelihoods inside the decodable object..
int32 frames_to_discard = num_frames_decoded_ -
decodable_.FirstAvailableFrame();
KALDI_ASSERT(frames_to_discard >= 0);
num_frames_output += num_loglike_frames;
decodable_.AcceptLoglikes(&loglikes, frames_to_discard);
if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer))
return false;
break; // break from the innermost while loop.
} else {
// we want the next call to Lock to block until the decoder has
// processed more frames.
if (!decodable_synchronizer_.UnlockFailure(ThreadSynchronizer::kProducer))
return false;
}
}
}
if (last_time) {
// Inform the decodable object that there will be no more input.
if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kProducer))
return false;
decodable_.InputIsFinished();
if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer))
return false;
return true;
}
}
}
bool SingleUtteranceNnet2DecoderThreaded::RunDecoderSearchInternal() {
int32 num_frames_decoded = 0; // this is just a copy of decoder_->NumFramesDecoded();
decoder_.InitDecoding();
while (true) { // decode at most one frame each loop.
if (!decodable_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
return false; // AbortAllThreads() called.
if (decodable_.NumFramesReady() <= num_frames_decoded) {
// no frames available to decode.
KALDI_ASSERT(decodable_.NumFramesReady() == num_frames_decoded);
if (decodable_.IsLastFrame(num_frames_decoded - 1)) {
decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer);
return true; // exit from this thread; we're done.
} else {
// we were not able to advance the decoding due to no available
// input. The next call will ensure that the next call to
// decodable_synchronizer_.Lock() will wait.
if (!decodable_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
return false;
}
} else {
// Decode at most config_.decode_batch_size frames (e.g. 1 or 2).
decoder_mutex_.Lock();
decoder_.AdvanceDecoding(&decodable_, config_.decode_batch_size);
num_frames_decoded = decoder_.NumFramesDecoded();
decoder_mutex_.Unlock();
num_frames_decoded_ = num_frames_decoded;
if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
return false;
}
}
}
bool SingleUtteranceNnet2DecoderThreaded::EndpointDetected(
const OnlineEndpointConfig &config) {
decoder_mutex_.Lock();
bool ans = kaldi::EndpointDetected(config, tmodel_,
feature_pipeline_.FrameShiftInSeconds(),
decoder_);
decoder_mutex_.Unlock();
return ans;
}
} // namespace kaldi

Просмотреть файл

@ -0,0 +1,415 @@
// online2/online-nnet2-decoding-threaded.h
// Copyright 2014-2015 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_
#define KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_
#include <string>
#include <vector>
#include <deque>
#include "matrix/matrix-lib.h"
#include "util/common-utils.h"
#include "base/kaldi-error.h"
#include "decoder/decodable-matrix.h"
#include "nnet2/am-nnet.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/online-endpoint.h"
#include "decoder/lattice-faster-online-decoder.h"
#include "hmm/transition-model.h"
#include "thread/kaldi-mutex.h"
#include "thread/kaldi-semaphore.h"
namespace kaldi {
/// @addtogroup onlinedecoding OnlineDecoding
/// @{
/**
class ThreadSynchronizer acts to guard an arbitrary type of buffer between a
producing and a consuming thread (note: it's all symmetric between the two
thread types). It has a similar interface to a mutex, except that instead of
just Lock and Unlock, it has Lock, UnlockSuccess and UnlockFailure, and each
function takes an argument kProducer or kConsumer to identify whether the
producing or consuming thread is waiting.
The basic concept is that you lock the object; and if you discover the you're
blocked because you're either trying to read an empty buffer or trying to
write to a full buffer, you unlock with UnlockFailure; and this will cause
your next call to Lock to block until the *other* thread has called Lock and
then UnlockSuccess. However, if at that point the other thread calls Lock
and then UnlockFailure, it is an error because you can't have both producing
and consuming threads claiming that the buffer is full/empty. If you lock
the object and were successful you call UnlockSuccess; and you call
UnlockSuccess even if, for your own reasons, you ended up not changing the
state of the buffer.
*/
class ThreadSynchronizer {
public:
ThreadSynchronizer();
// Most calls to this class should provide the thread-type of the caller,
// producing or consuming. Actually the behavior of this class is symmetric
// between the two types of thread.
enum ThreadType { kProducer, kConsumer };
// All functions returning bool will return true normally, and false if
// SetAbort() was set; if they return false, you should probably call SetAbort()
// on any other ThreadSynchronizer classes you are using and then return from
// the thread.
// call this to lock the object being guarded.
bool Lock(ThreadType t);
// Call this to unlock the object being guarded, if you don't want the next call to
// Lock to stall.
bool UnlockSuccess(ThreadType t);
// Call this if you want the next call to Lock() to stall until the other
// (producer/consumer) thread has locked and then unlocked the mutex. Note
// that, if the other thread then calls Lock and then UnlockFailure, this will
// generate a printed warning (and if repeated too many times, an exception).
bool UnlockFailure(ThreadType t);
// Sets abort_ flag so future calls will return false, and future calls to
// Lock() won't lock the mutex but will immediately return false.
void SetAbort();
~ThreadSynchronizer();
private:
bool abort_;
bool producer_waiting_; // true if producer is/will be waiting on semaphore
bool consumer_waiting_; // true if consumer is/will be waiting on semaphore
Mutex mutex_; // Locks the buffer object.
ThreadType held_by_; // Record of which thread is holding the mutex (if
// held); else undefined. Used for validation of input.
Semaphore producer_semaphore_; // The producer thread waits on this semaphore
Semaphore consumer_semaphore_; // The consumer thread waits on this semaphore
int32 num_errors_; // Rumber of times the threads alternated doing Lock() and
// UnlockFailure(). This should not happen at all; but
// it's more user-friendly to simply warn a few times; and then
// only after a number of errors, to fail.
KALDI_DISALLOW_COPY_AND_ASSIGN(ThreadSynchronizer);
};
// This is the configuration class for SingleUtteranceNnet2DecoderThreaded. The
// actual command line program requires other configs that it creates
// separately, and which are not included here: namely,
// OnlineNnet2FeaturePipelineConfig and OnlineEndpointConfig.
struct OnlineNnet2DecodingThreadedConfig {
LatticeFasterDecoderConfig decoder_opts;
BaseFloat acoustic_scale;
int32 max_buffered_features; // maximum frames of features we allow to be
// held in the feature buffer before we block
// the feature-processing thread.
int32 feature_batch_size; // maximum number of frames at a time that we decode
// before unlocking the mutex. The only real cost
// here is a mutex lock/unlock, so it's OK to make
// this fairly small.
int32 max_loglikes_copy; // maximum unused frames of log-likelihoods we will
// copy from the decodable object back into another
// matrix to be supplied to the decodable object.
// make this too large-> will block the
// decoder-search thread while copying; too small
// -> the nnet-evaluation thread may get blocked
// for too long while waiting for the decodable
// thread to be ready.
int32 nnet_batch_size; // batch size (number of frames) we evaluate in the
// neural net, if this many is available. To take
// best advantage of BLAS, you may want to set this
// fairly large, e.g. 32 or 64 frames. It probably
// makes sense to tune this a bit.
int32 decode_batch_size; // maximum number of frames at a time that we decode
// before unlocking the mutex. The only real cost
// here is a mutex lock/unlock, so it's OK to make
// this fairly small.
OnlineNnet2DecodingThreadedConfig() {
acoustic_scale = 0.1;
max_buffered_features = 100;
feature_batch_size = 2;
nnet_batch_size = 32;
max_loglikes_copy = 20;
decode_batch_size = 2;
}
void Check();
void Register(OptionsItf *po) {
decoder_opts.Register(po);
po->Register("acoustic-scale", &acoustic_scale, "Scale used on acoustics "
"when decoding");
po->Register("max-buffered-features", &max_buffered_features, "Obscure "
"setting, affects multi-threaded decoding.");
po->Register("feature-batch-size", &max_buffered_features, "Obscure "
"setting, affects multi-threaded decoding.");
po->Register("nnet-batch-size", &nnet_batch_size, "Maximum batch size "
"(in frames) used when evaluating neural net likelihoods");
po->Register("max-loglikes-copy", &max_loglikes_copy, "Obscure "
"setting, affects multi-threaded decoding.");
po->Register("decode-batch-sie", &decode_batch_size, "Obscure "
"setting, affects multi-threaded decoding.");
}
};
/**
You will instantiate this class when you want to decode a single
utterance using the online-decoding setup for neural nets. Each time this
class is created, it creates three background threads, and the feature
extraction, neural net evaluation, and search aspects of decoding all
happen in different threads.
Note: we assume that all calls to its public interface happen from a single
thread.
*/
class SingleUtteranceNnet2DecoderThreaded {
public:
// Constructor. Unlike SingleUtteranceNnet2Decoder, we create the
// feature_pipeline object inside this class, since access to it needs to be
// controlled by a mutex and this class knows how to handle that. The
// feature_info and adaptation_state arguments are used to initialize the
// (locally owned) feature pipeline.
SingleUtteranceNnet2DecoderThreaded(
const OnlineNnet2DecodingThreadedConfig &config,
const TransitionModel &tmodel,
const nnet2::AmNnet &am_nnet,
const fst::Fst<fst::StdArc> &fst,
const OnlineNnet2FeaturePipelineInfo &feature_info,
const OnlineIvectorExtractorAdaptationState &adaptation_state);
/// You call this to provide this class with more waveform to decode. This
/// call is, for all practical purposes, non-blocking.
void AcceptWaveform(BaseFloat samp_freq,
const VectorBase<BaseFloat> &wave_part);
/// You call this to inform the class that no more waveform will be provided;
/// this allows it to flush out the last few frames of features, and is
/// necessary if you want to call Wait() to wait until all decoding is done.
/// After calling InputFinished() you cannot call AcceptWaveform any more.
void InputFinished();
/// You can call this if you don't want the decoding to proceed further with
/// this utterance. It just won't do any more processing, but you can still
/// use the lattice from the decoding that it's already done. Note: it may
/// still continue decoding up to decode_batch_size (default: 2) frames of
/// data before the decoding thread exits. You can call Wait() after calling
/// this, if you want to wait for that.
void TerminateDecoding();
/// This call will block until all the data has been decoded; it must only be
/// called after either InputFinished() has been called or TerminateDecoding() has
/// been called; otherwise, to call it is an error.
void Wait();
/// Finalizes the decoding. Cleans up and prunes remaining tokens, so the final
/// lattice is faster to obtain. May not be called unless either InputFinished()
/// or TerminateDecoding() has been called. If InputFinished() was called, it
/// calls Wait() to ensure that the decoding has finished (it's not an error
/// if you already called Wait()).
void FinalizeDecoding();
/// Returns *approximately* (ignoring end effects), the number of frames of
/// data that we expect given the amount of data that the pipeline has
/// received via AcceptWaveform(). (ignores small end effects). This might
/// be useful in application code to compare with NumFramesDecoded() and gauge
/// how much latency there is.
int32 NumFramesReceivedApprox() const;
/// Returns the number of frames currently decoded. Caution: don't rely on
/// the lattice having exactly this number if you get it after this call, as
/// it may increase after this-- unless you've already called either
/// TerminateDecoding() or InputFinished(), followed by Wait().
int32 NumFramesDecoded() const;
/// Gets the lattice. The output lattice has any acoustic scaling in it
/// (which will typically be desirable in an online-decoding context); if you
/// want an un-scaled lattice, scale it using ScaleLattice() with the inverse
/// of the acoustic weight. "end_of_utterance" will be true if you want the
/// final-probs to be included. If this is at the end of the utterance,
/// you might want to first call FinalizeDecoding() first; this will make this
/// call return faster.
/// If no frames have been decoded yet, it will set clat to a lattice with
/// a single state that is final and with unit weight (no cost or alignment).
/// The output to final_relative_cost (if non-NULL) is a number >= 0 that's
/// closer to 0 if a final-state was close to the best-likelihood state
/// active on the last frame, at the time we obtained the lattice.
void GetLattice(bool end_of_utterance,
CompactLattice *clat,
BaseFloat *final_relative_cost) const;
/// Outputs an FST corresponding to the single best path through the current
/// lattice. If "use_final_probs" is true AND we reached the final-state of
/// the graph then it will include those as final-probs, else it will treat
/// all final-probs as one.
/// If no frames have been decoded yet, it will set best_path to a lattice with
/// a single state that is final and with unit weight (no cost).
/// The output to final_relative_cost (if non-NULL) is a number >= 0 that's
/// closer to 0 if a final-state were close to the best-likelihood state
/// active on the last frame, at the time we got the best path.
void GetBestPath(bool end_of_utterance,
Lattice *best_path,
BaseFloat *final_relative_cost) const;
/// This function calls EndpointDetected from online-endpoint.h,
/// with the required arguments.
bool EndpointDetected(const OnlineEndpointConfig &config);
/// Outputs the adaptation state of the feature pipeline to "adaptation_state". This
/// mostly stores stats for iVector estimation, and will generally be called at the
/// end of an utterance, assuming it's a scenario where each speaker is seen for
/// more than one utterance.
/// You may only call this function after either calling TerminateDecoding() or
/// InputFinished, and then Wait(). Otherwise it is an error.
void GetAdaptationState(OnlineIvectorExtractorAdaptationState *adaptation_state);
~SingleUtteranceNnet2DecoderThreaded();
private:
// This function will instruct all threads to abort operation as soon as they
// can safely do so, by calling SetAbort() in the threads
void AbortAllThreads(bool error);
// This function waits for all the threads that have been spawned, and then
// sets the pointers in threads_ to NULL; it is called in the destructor and
// from Wait(). If called twice it is not an error.
void WaitForAllThreads();
// this function runs the thread that does the feature extraction; ptr_in is
// to class SingleUtteranceNnet2DecoderThreaded. Always returns NULL, but in
// case of failure, calls ptr_in->AbortAllThreads(true).
static void* RunFeatureExtraction(void *ptr_in);
// member-function version of RunFeatureExtraction, called by RunFeatureExgtraction.
bool RunFeatureExtractionInternal();
// this function runs the thread that does the neural-net evaluation ptr_in is
// to class SingleUtteranceNnet2DecoderThreaded. Always returns NULL, but in
// case of failure, calls ptr_in->AbortAllThreads(true).
static void* RunNnetEvaluation(void *ptr_in);
// member-function version of RunNnetEvaluation, called by RunNnetEvaluation.
bool RunNnetEvaluationInternal();
// this function runs the thread that does the neural-net evaluation ptr_in is
// to class SingleUtteranceNnet2DecoderThreaded. Always returns NULL, but in
// case of failure, calls ptr_in->AbortAllThreads(true).
static void* RunDecoderSearch(void *ptr_in);
// member-function version of RunDecoderSearch, called by RunDecoderSearch.
bool RunDecoderSearchInternal();
// Member variables:
OnlineNnet2DecodingThreadedConfig config_;
const nnet2::AmNnet &am_nnet_;
const TransitionModel &tmodel_;
// sampling_rate_ is set the first time AcceptWaveform is called.
BaseFloat sampling_rate_;
// A record of how many samples have been provided so
// far via calls to AcceptWaveform.
int64 num_samples_received_;
// The next two variables are written to by AcceptWaveform from the main
// thread, and read by the feature-processing thread; they are guarded by
// waveform_synchronizer_. There is no bound on the buffer size here.
// Later-arriving data is appended to the vector. When InputFinished() is
// called from the main thread, the main thread sets input_finished_ = true.
// sampling_rate_ is only needed for checking that it matches the config.
bool input_finished_;
std::vector< Vector<BaseFloat>* > input_waveform_;
ThreadSynchronizer waveform_synchronizer_;
// feature_pipeline_ is only accessed by the feature-processing thread, and by
// the main thread if GetAdaptionState() is called. It is guarded by feature_pipeline_mutex_.
OnlineNnet2FeaturePipeline feature_pipeline_;
Mutex feature_pipeline_mutex_;
// The following three variables are guarded by feature_synchronizer_; they are
// a buffer of features that is passed between the feature-processing thread
// and the neural-net evaluation thread. feature_buffer_start_frame_ is the
// frame index of the first feature vector in the buffer. After the
// neural-net evaluation thread reads the features, it removes them from the
// vector "feature_buffer_" and advances "feature_buffer_start_frame_"
// appropriately. The feature-processing thread does not advance
// feature_buffer_start_frame_, but appends to feature_buffer_. After
// all features are done (because user called InputFinished()), the
// feature-processing thread sets feature_buffer_finished_.
int32 feature_buffer_start_frame_;
bool feature_buffer_finished_;
std::vector<Vector<BaseFloat>* > feature_buffer_;
ThreadSynchronizer feature_synchronizer_;
// this Decodable object just stores a matrix of scaled log-likelihoods
// obtained by the nnet-evaluation thread. It is produced by the
// nnet-evaluation thread and consumed by the decoder-search thread. The
// decoding thread sets num_frames_decoded_ so the nnet-evaluation thread
// knows which frames of log-likelihoods it can discard. Both of these
// variables are guarded by decodable_synchronizer_. Note:
// the num_frames_decoded_ may be less than the current number of frames
// the decoder has decoded; the decoder thread sets this variable when it
// locks this mutex.
DecodableMatrixMappedOffset decodable_;
int32 num_frames_decoded_;
ThreadSynchronizer decodable_synchronizer_;
// the decoder_ object contains everything related to the graph search.
LatticeFasterOnlineDecoder decoder_;
// decoder_mutex_ guards the decoder_ object. It is usually held by the decoding
// thread (where it is released and re-obtained on each frame), but is obtained
// by the main (parent) thread if you call functions like NumFramesDecoded(),
// GetLattice() and GetBestPath().
Mutex decoder_mutex_;
// This contains the thread pointers for the feature-extraction,
// nnet-evaluation, and decoder-search threads respectively (or NULL if they
// have been joined in Wait()).
pthread_t threads_[3];
// This is set to true if AbortAllThreads was called for any reason, including
// if someone called TerminateDecoding().
bool abort_;
// This is set to true if any kind of unexpected error is encountered,
// including if exceptions are raised in any of the threads. Will normally
// be a coding error, malloc failure-- something we should never encounter.
bool error_;
};
/// @} End of "addtogroup onlinedecoding"
} // namespace kaldi
#endif // KALDI_ONLINE2_ONLINE_NNET2_DECODING_THREADED_H_

Просмотреть файл

@ -33,7 +33,7 @@ SingleUtteranceNnet2Decoder::SingleUtteranceNnet2Decoder(
feature_pipeline_(feature_pipeline),
tmodel_(tmodel),
decodable_(model, tmodel, config.decodable_opts, feature_pipeline),
decoder_(fst, config.faster_decoder_opts) {
decoder_(fst, config.decoder_opts) {
decoder_.InitDecoding();
}
@ -56,12 +56,12 @@ void SingleUtteranceNnet2Decoder::GetLattice(bool end_of_utterance,
Lattice raw_lat;
decoder_.GetRawLattice(&raw_lat, end_of_utterance);
if (!config_.faster_decoder_opts.determinize_lattice)
if (!config_.decoder_opts.determinize_lattice)
KALDI_ERR << "--determinize-lattice=false option is not supported at the moment";
BaseFloat lat_beam = config_.faster_decoder_opts.lattice_beam;
BaseFloat lat_beam = config_.decoder_opts.lattice_beam;
DeterminizeLatticePhonePrunedWrapper(
tmodel_, &raw_lat, lat_beam, clat, config_.faster_decoder_opts.det_opts);
tmodel_, &raw_lat, lat_beam, clat, config_.decoder_opts.det_opts);
}
void SingleUtteranceNnet2Decoder::GetBestPath(bool end_of_utterance,

Просмотреть файл

@ -49,13 +49,13 @@ namespace kaldi {
// here: namely, OnlineNnet2FeaturePipelineConfig and OnlineEndpointConfig.
struct OnlineNnet2DecodingConfig {
LatticeFasterDecoderConfig faster_decoder_opts;
LatticeFasterDecoderConfig decoder_opts;
nnet2::DecodableNnet2OnlineOptions decodable_opts;
OnlineNnet2DecodingConfig() { decodable_opts.acoustic_scale = 0.1; }
void Register(OptionsItf *po) {
faster_decoder_opts.Register(po);
decoder_opts.Register(po);
decodable_opts.Register(po);
}
};
@ -77,8 +77,9 @@ class SingleUtteranceNnet2Decoder {
/// advance the decoding as far as we can.
void AdvanceDecoding();
/// Finalizes the decoding. Cleanups and prunes remaining tokens, so the final
/// result is faster to obtain.
/// Finalizes the decoding. Cleans up and prunes remaining tokens, so the
/// GetLattice() call will return faster. You must not call this before
/// calling (TerminateDecoding() or InputIsFinished()) and then Wait().
void FinalizeDecoding();
int32 NumFramesDecoded() const;
@ -98,13 +99,6 @@ class SingleUtteranceNnet2Decoder {
void GetBestPath(bool end_of_utterance,
Lattice *best_path) const;
/// This function outputs to "final_relative_cost", if non-NULL, a number >= 0
/// that will be close to zero if the final-probs were close to the best probs
/// active on the final frame. (the output to final_relative_cost is based on
/// the first-pass decoding). If it's close to zero (e.g. < 5, as a guess),
/// it means you reached the end of the grammar with good probability, which
/// can be taken as a good sign that the input was OK.
BaseFloat FinalRelativeCost() { return decoder_.FinalRelativeCost(); }
/// This function calls EndpointDetected from online-endpoint.h,
/// with the required arguments.

Просмотреть файл

@ -184,4 +184,5 @@ BaseFloat OnlineNnet2FeaturePipelineInfo::FrameShiftInSeconds() const {
}
}
} // namespace kaldi

Просмотреть файл

@ -203,9 +203,7 @@ class OnlineNnet2FeaturePipeline: public OnlineFeatureInterface {
void AcceptWaveform(BaseFloat sampling_rate,
const VectorBase<BaseFloat> &waveform);
BaseFloat FrameShiftInSeconds() const {
return info_.FrameShiftInSeconds();
}
BaseFloat FrameShiftInSeconds() const { return info_.FrameShiftInSeconds(); }
/// If you call InputFinished(), it tells the class you won't be providing any
/// more waveform. This will help flush out the last few frames of delta or

Просмотреть файл

@ -36,7 +36,11 @@ void OnlineTimingStats::Print(bool online){
KALDI_LOG << "Timing stats: real-time factor was " << real_time_factor
<< " (note: this cannot be less than one.)";
KALDI_LOG << "Average delay was " << average_wait << " seconds.";
if (idle_percent != 0.0) {
// If the user was calling SleepUntil instead of WaitUntil, this will
// always be zero; so don't print it in that case.
KALDI_LOG << "Percentage of time spent idling was " << idle_percent;
}
KALDI_LOG << "Longest delay was " << max_delay_ << " seconds for utterance "
<< '\'' << max_delay_utt_ << '\'';
} else {
@ -74,6 +78,17 @@ void OnlineTimer::WaitUntil(double cur_utterance_length) {
utterance_length_ = cur_utterance_length;
}
void OnlineTimer::SleepUntil(double cur_utterance_length) {
KALDI_ASSERT(waited_ == 0 && "Do not mix SleepUntil with WaitUntil.");
double elapsed = timer_.Elapsed();
double to_wait = cur_utterance_length - elapsed;
if (to_wait > 0.0) {
Sleep(to_wait);
}
utterance_length_ = cur_utterance_length;
}
double OnlineTimer::Elapsed() {
return timer_.Elapsed() + waited_;
}
@ -88,6 +103,9 @@ void OnlineTimer::OutputStats(OnlineTimingStats *stats) {
KALDI_WARN << "Negative wait time " << wait_time
<< " does not make sense.";
}
KALDI_VLOG(2) << "Latency " << wait_time << " seconds out of "
<< utterance_length_ << ", for utterance "
<< utterance_id_;
stats->num_utts_++;
stats->total_audio_ += utterance_length_;

Просмотреть файл

@ -50,8 +50,10 @@ class OnlineTimingStats {
int32 num_utts_;
// all times are in seconds.
double total_audio_; // total time of audio.
double total_time_taken_;
double total_time_waited_; // total time in wait() state.
double total_time_taken_; // total time spent processing the audio.
double total_time_waited_; // total time we pretended to wait (but just
// increased the waited_ counter)... zero if you
// called SleepUntil instead of WaitUntil().
double max_delay_; // maximum delay at utterance end.
std::string max_delay_utt_;
};
@ -65,7 +67,8 @@ class OnlineTimingStats {
/// available in a real-time application-- e.g. say we need to process a chunk
/// that ends half a second into the utterance, we would sleep until half a
/// second had elapsed since the start of the utterance. In this code we
/// don't actually sleep; we simulate the effect of sleeping by just incrementing
/// have the option to not actually sleep: we can simulate the effect of
/// sleeping by just incrementing
/// a variable that says how long we would have slept; and we add this to
/// wall-clock times obtained from Timer::Elapsed().
/// The usage of this class will be something like as follows:
@ -86,8 +89,14 @@ class OnlineTimer {
public:
OnlineTimer(const std::string &utterance_id);
/// The call to WaitUntil(t) simulates the effect of waiting
/// until t seconds after this object was initialized.
/// The call to SleepUntil(t) will sleep until cur_utterance_length seconds
/// after this object was initialized, or return immediately if we've
/// already passed that time.
void SleepUntil(double cur_utterance_length);
/// The call to WaitUntil(t) simulates the effect of sleeping until
/// cur_utterance_length seconds after this object was initialized;
/// but instead of actually sleeping, it increases a counter.
void WaitUntil(double cur_utterance_length);
/// This call, which should be made after decoding is done,

Просмотреть файл

@ -10,7 +10,7 @@ BINFILES = online2-wav-gmm-latgen-faster apply-cmvn-online \
extend-wav-with-silence compress-uncompress-speex \
online2-wav-nnet2-latgen-faster ivector-extract-online2 \
online2-wav-dump-features ivector-randomize \
online2-wav-nnet2-am-compute
online2-wav-nnet2-am-compute online2-wav-nnet2-latgen-threaded
OBJFILES =

Просмотреть файл

@ -197,7 +197,9 @@ void FindQuietestSegment(const Vector<BaseFloat> &wav_in,
start += seg_shift;
}
KALDI_ASSERT(min_energy > 0.0);
if (min_energy == 0.0) {
KALDI_WARN << "Zero energy silence being used.";
}
*wav_sil = wav_min_energy;
}

Просмотреть файл

@ -88,7 +88,8 @@ int main(int argc, char *argv[]) {
"<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
"The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
"you want to decode utterance by utterance.\n"
"See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n";
"See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n"
"See also online2-wav-nnet2-latgen-threaded\n";
ParseOptions po(usage);

Просмотреть файл

@ -0,0 +1,277 @@
// onlinebin/online2-wav-nnet2-latgen-thread.cc
// Copyright 2014-2015 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "feat/wave-reader.h"
#include "online2/online-nnet2-decoding-threaded.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
namespace kaldi {
void GetDiagnosticsAndPrintOutput(const std::string &utt,
const fst::SymbolTable *word_syms,
const CompactLattice &clat,
int64 *tot_num_frames,
double *tot_like) {
if (clat.NumStates() == 0) {
KALDI_WARN << "Empty lattice.";
return;
}
CompactLattice best_path_clat;
CompactLatticeShortestPath(clat, &best_path_clat);
Lattice best_path_lat;
ConvertLattice(best_path_clat, &best_path_lat);
double likelihood;
LatticeWeight weight;
int32 num_frames;
std::vector<int32> alignment;
std::vector<int32> words;
GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight);
num_frames = alignment.size();
likelihood = -(weight.Value1() + weight.Value2());
*tot_num_frames += num_frames;
*tot_like += likelihood;
KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is "
<< (likelihood / num_frames) << " over " << num_frames
<< " frames.";
if (word_syms != NULL) {
std::cerr << utt << ' ';
for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms->Find(words[i]);
if (s == "")
KALDI_ERR << "Word-id " << words[i] << " not in symbol table.";
std::cerr << s << ' ';
}
std::cerr << std::endl;
}
}
}
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
using namespace fst;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;
const char *usage =
"Reads in wav file(s) and simulates online decoding with neural nets\n"
"(nnet2 setup), with optional iVector-based speaker adaptation and\n"
"optional endpointing. This version uses multiple threads for decoding.\n"
"Note: some configuration values and inputs are set via config files\n"
"whose filenames are passed as options\n"
"\n"
"Usage: online2-wav-nnet2-latgen-threaded [options] <nnet2-in> <fst-in> "
"<spk2utt-rspecifier> <wav-rspecifier> <lattice-wspecifier>\n"
"The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if\n"
"you want to decode utterance by utterance.\n"
"See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n"
"See also online2-wav-nnet2-latgen-faster\n";
ParseOptions po(usage);
std::string word_syms_rxfilename;
OnlineEndpointConfig endpoint_config;
// feature_config includes configuration for the iVector adaptation,
// as well as the basic features.
OnlineNnet2FeaturePipelineConfig feature_config;
OnlineNnet2DecodingThreadedConfig nnet2_decoding_config;
BaseFloat chunk_length_secs = 0.05;
bool do_endpointing = false;
bool modify_ivector_config = false;
po.Register("chunk-length", &chunk_length_secs,
"Length of chunk size in seconds, that we provide each time to the "
"decoder. The actual chunk sizes it processes for various stages "
"of decoding are dynamically determinated, and unrelated to this");
po.Register("word-symbol-table", &word_syms_rxfilename,
"Symbol table for words [for debug output]");
po.Register("do-endpointing", &do_endpointing,
"If true, apply endpoint detection");
po.Register("modify-ivector-config", &modify_ivector_config,
"If true, modifies the iVector configuration from the config files "
"by setting --use-most-recent-ivector=true and --greedy-ivector-extractor=true. "
"This will give the best possible results, but the results may become dependent "
"on the speed of your machine (slower machine -> better results). Compare "
"to the --online option in online2-wav-nnet-latgen-faster");
feature_config.Register(&po);
nnet2_decoding_config.Register(&po);
endpoint_config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() != 5) {
po.PrintUsage();
return 1;
}
std::string nnet2_rxfilename = po.GetArg(1),
fst_rxfilename = po.GetArg(2),
spk2utt_rspecifier = po.GetArg(3),
wav_rspecifier = po.GetArg(4),
clat_wspecifier = po.GetArg(5);
OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
if (modify_ivector_config) {
feature_info.ivector_extractor_info.use_most_recent_ivector = true;
feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
}
TransitionModel trans_model;
nnet2::AmNnet am_nnet;
{
bool binary;
Input ki(nnet2_rxfilename, &binary);
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
}
fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldi(fst_rxfilename);
fst::SymbolTable *word_syms = NULL;
if (word_syms_rxfilename != "")
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
KALDI_ERR << "Could not read symbol table from file "
<< word_syms_rxfilename;
int32 num_done = 0, num_err = 0;
double tot_like = 0.0;
int64 num_frames = 0;
SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
CompactLatticeWriter clat_writer(clat_wspecifier);
OnlineTimingStats timing_stats;
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
std::string spk = spk2utt_reader.Key();
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
OnlineIvectorExtractorAdaptationState adaptation_state(
feature_info.ivector_extractor_info);
for (size_t i = 0; i < uttlist.size(); i++) {
std::string utt = uttlist[i];
if (!wav_reader.HasKey(utt)) {
KALDI_WARN << "Did not find audio for utterance " << utt;
num_err++;
continue;
}
const WaveData &wave_data = wav_reader.Value(utt);
// get the data for channel zero (if the signal is not mono, we only
// take the first channel).
SubVector<BaseFloat> data(wave_data.Data(), 0);
SingleUtteranceNnet2DecoderThreaded decoder(
nnet2_decoding_config, trans_model, am_nnet,
*decode_fst, feature_info, adaptation_state);
OnlineTimer decoding_timer(utt);
BaseFloat samp_freq = wave_data.SampFreq();
int32 chunk_length;
KALDI_ASSERT(chunk_length_secs > 0);
chunk_length = int32(samp_freq * chunk_length_secs);
if (chunk_length == 0) chunk_length = 1;
int32 samp_offset = 0;
while (samp_offset < data.Dim()) {
int32 samp_remaining = data.Dim() - samp_offset;
int32 num_samp = chunk_length < samp_remaining ? chunk_length
: samp_remaining;
SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
decoder.AcceptWaveform(samp_freq, wave_part);
samp_offset += num_samp;
// Note: the next call may actually call sleep().
decoding_timer.SleepUntil(samp_offset / samp_freq);
if (samp_offset == data.Dim()) {
// no more input. flush out last frames
decoder.InputFinished();
}
if (do_endpointing && decoder.EndpointDetected(endpoint_config)) {
decoder.TerminateDecoding();
break;
}
}
Timer timer;
decoder.Wait();
KALDI_VLOG(1) << "Waited " << timer.Elapsed() << " seconds for decoder to "
<< "finish after giving it last chunk.";
decoder.FinalizeDecoding();
CompactLattice clat;
bool end_of_utterance = true;
decoder.GetLattice(end_of_utterance, &clat, NULL);
GetDiagnosticsAndPrintOutput(utt, word_syms, clat,
&num_frames, &tot_like);
decoding_timer.OutputStats(&timing_stats);
// In an application you might avoid updating the adaptation state if
// you felt the utterance had low confidence. See lat/confidence.h
decoder.GetAdaptationState(&adaptation_state);
// we want to output the lattice with un-scaled acoustics.
BaseFloat inv_acoustic_scale =
1.0 / nnet2_decoding_config.acoustic_scale;
ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);
KALDI_VLOG(1) << "Adding the various end-of-utterance tasks took the "
<< "total latency to " << timer.Elapsed() << " seconds.";
clat_writer.Write(utt, clat);
KALDI_LOG << "Decoded utterance " << utt;
num_done++;
}
}
bool online = true;
timing_stats.Print(online);
KALDI_LOG << "Decoded " << num_done << " utterances, "
<< num_err << " with errors.";
KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames)
<< " per frame over " << num_frames << " frames.";
delete decode_fst;
delete word_syms; // will delete if non-NULL.
return (num_done != 0 ? 0 : 1);
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
}
} // main()

Просмотреть файл

@ -46,7 +46,8 @@ Mutex::~Mutex() {
<< "a known issue that affects Haswell processors, see "
<< "https://sourceware.org/bugzilla/show_bug.cgi?id=16657 "
<< "If your processor is not Haswell and you see this message, "
<< "it could be a bug in Kaldi.";
<< "it could be a bug in Kaldi. However it could be that "
<< "multi-threaded code terminated messily.";
}
}
}

Просмотреть файл

@ -2010,6 +2010,7 @@ template<class Holder> class RandomAccessTableReaderUnsortedArchiveImpl:
if (!pr.second) { // Was not inserted-- previous element w/ same key
delete holder_; // map was not changed, no ownership transferred.
holder_ = NULL;
KALDI_ERR << "Error in RandomAccessTableReader: duplicate key "
<< cur_key_ << " in archive " << archive_rxfilename_;
}