Mirror of https://github.com/mozilla/kaldi.git
trunk: updating online-nnet2-decoding setup to allow for downweighting of silence in the stats for iVector estimation.
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4972 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Parent
9045a18f63
Commit
0ddf4bba08
@@ -168,7 +168,7 @@ if [ $stage -le 13 ]; then
done
fi

#exit 0;
exit 0;
###### Comment out the "exit 0" above to run the multi-threaded decoding. #####

if [ $stage -le 14 ]; then
@@ -189,4 +189,14 @@ if [ $stage -le 15 ]; then
    ${dir}_online/decode_pp_${test}_tgsmall_utt_threaded_ep || exit 1;
fi

if [ $stage -le 16 ]; then
  # Demonstrate the multi-threaded decoding with silence excluded
  # from iVector estimation.
  test=dev_clean
  steps/online/nnet2/decode.sh --threaded true --silence-weight 0.0 \
    --config conf/decode.config --cmd "$decode_cmd" --nj 30 \
    --per-utt true exp/tri6b/graph_pp_tgsmall data/$test \
    ${dir}_online/decode_pp_${test}_tgsmall_utt_threaded_sil0.0 || exit 1;
fi

exit 0;
@@ -66,7 +66,7 @@ for f in $graphdir/HCLG.fst $data/feats.scp $model $extra_files; do
done

sdata=$data/split$nj;
cmvn_opts=`cat $srcdir/cmvn_opts 2>/dev/null`
cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1;
thread_string=
[ $num_threads -gt 1 ] && thread_string="-parallel --num-threads=$num_threads"
@@ -20,6 +20,10 @@ do_endpointing=false
do_speex_compressing=false
scoring_opts=
skip_scoring=false
silence_weight=1.0  # set this to a value less than 1 (e.g. 0) to enable silence weighting.
max_state_duration=40  # This only has an effect if you are doing silence
   # weighting.  This default is probably reasonable.  transition-ids repeated
   # more than this many times in an alignment are treated as silence.
iter=final
# End configuration section.
@@ -94,6 +98,12 @@ if $do_endpointing; then
  wav_rspecifier="$wav_rspecifier extend-wav-with-silence ark:- ark:- |"
fi

if [ "$silence_weight" != "1.0" ]; then
  silphones=$(cat $graphdir/phones/silence.csl) || exit 1
  silence_weighting_opts="--ivector-silence-weighting.max-state-duration=$max_state_duration --ivector-silence-weighting.silence_phones=$silphones --ivector-silence-weighting.silence-weight=$silence_weight"
else
  silence_weighting_opts=
fi

if $threaded; then
@@ -110,7 +120,7 @@ fi

if [ $stage -le 0 ]; then
  $cmd $parallel_opts JOB=1:$nj $dir/log/decode.JOB.log \
    $decoder $opts --do-endpointing=$do_endpointing \
    $decoder $opts $silence_weighting_opts --do-endpointing=$do_endpointing \
     --config=$srcdir/conf/online_nnet2_decoding.conf \
     --max-active=$max_active --beam=$beam --lattice-beam=$lattice_beam \
     --acoustic-scale=$acwt --word-symbol-table=$graphdir/words.txt \
@@ -22,8 +22,10 @@ if [ "$1" == "--per-utt" ]; then
fi

if [ $# != 2 ]; then
  echo "Usage: split_data.sh <data-dir> <num-to-split>"
  echo "Usage: split_data.sh [--per-utt] <data-dir> <num-to-split>"
  echo "This script will not split the data-dir if it detects that the output is newer than the input."
  echo "By default it splits per speaker (so each speaker is in only one split dir),"
  echo "but with the --per-utt option it will ignore the speaker information while splitting."
  exit 1
fi
@@ -45,13 +47,11 @@ nu=`cat $data/utt2spk | wc -l`
nf=`cat $data/feats.scp 2>/dev/null | wc -l`
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
  echo "** split_data.sh: warning, #lines in (utt2spk,feats.scp) is ($nu,$nf); this script "
  echo "** may produce incorrectly split data."
  echo "** split_data.sh: warning, #lines in (utt2spk,feats.scp) is ($nu,$nf); you can "
  echo "** use utils/fix_data_dir.sh $data to fix this."
fi
if [ -f $data/text ] && [ $nu -ne $nt ]; then
  echo "** split_data.sh: warning, #lines in (utt2spk,text) is ($nu,$nt); this script "
  echo "** may produce incorrectly split data."
  echo "** split_data.sh: warning, #lines in (utt2spk,text) is ($nu,$nt); you can "
  echo "** use utils/fix_data_dir.sh to fix this."
fi
@@ -74,11 +74,7 @@ fi

for n in `seq $numsplit`; do
  mkdir -p $data/split$numsplit/$n
  feats="$feats $data/split$numsplit/$n/feats.scp"
  vads="$vads $data/split$numsplit/$n/vad.scp"
  texts="$texts $data/split$numsplit/$n/text"
  utt2spks="$utt2spks $data/split$numsplit/$n/utt2spk"
  utt2langs="$utt2langs $data/split$numsplit/$n/utt2lang"
done

if $split_per_spk; then
@@ -87,37 +83,51 @@ else
  utt2spk_opt=
fi

utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1

[ -f $data/feats.scp ] && utils/split_scp.pl $utt2spk_opt $data/feats.scp $feats

[ -f $data/text ] && utils/split_scp.pl $utt2spk_opt $data/text $texts

[ -f $data/vad.scp ] && utils/split_scp.pl $utt2spk_opt $data/vad.scp $vads

[ -f $data/utt2lang ] && utils/split_scp.pl $utt2spk_opt $data/utt2lang $utt2langs

# If lockfile is not installed, just don't lock it.  It's not a big deal.
which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock

utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1

for n in `seq $numsplit`; do
  dsn=$data/split$numsplit/$n
  utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
done

maybe_wav_scp=
if [ ! -f $data/segments ]; then
  maybe_wav_scp=wav.scp  # If there is no segments file, then wav file is
                         # indexed per utt.
fi

# split some things that are indexed by utterance.
for f in feats.scp text vad.scp utt2lang $maybe_wav_scp; do
  if [ -f $data/$f ]; then
    utils/filter_scps.pl JOB=1:$numsplit \
      $data/split$numsplit/JOB/utt2spk $data/$f $data/split$numsplit/JOB/$f || exit 1;
  fi
done

# split some things that are indexed by speaker
for f in spk2gender spk2warp cmvn.scp; do
  if [ -f $data/$f ]; then
    utils/filter_scps.pl JOB=1:$numsplit \
      $data/split$numsplit/JOB/spk2utt $data/$f $data/split$numsplit/JOB/$f || exit 1;
  fi
done

for n in `seq $numsplit`; do
  dsn=$data/split$numsplit/$n
  utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
  for f in spk2gender spk2warp cmvn.scp; do
    [ -f $data/$f ] && \
      utils/filter_scp.pl $dsn/spk2utt $data/$f > $dsn/$f
  done
  if [ -f $data/segments ]; then
    utils/filter_scp.pl $dsn/utt2spk $data/segments > $dsn/segments
    awk '{print $2;}' $dsn/segments |sort|uniq > $data/tmp.reco # recording-ids.
    [ -f $data/reco2file_and_channel ] &&
      utils/filter_scp.pl $data/tmp.reco $data/reco2file_and_channel > $dsn/reco2file_and_channel
    [ -f $data/wav.scp ] && utils/filter_scp.pl $data/tmp.reco $data/wav.scp > $dsn/wav.scp
    awk '{print $2;}' $dsn/segments | sort | uniq > $data/tmp.reco # recording-ids.
    if [ -f $data/reco2file_and_channel ]; then
      utils/filter_scp.pl $data/tmp.reco $data/reco2file_and_channel > $dsn/reco2file_and_channel
    fi
    if [ -f $data/wav.scp ]; then
      utils/filter_scp.pl $data/tmp.reco $data/wav.scp >$dsn/wav.scp
    fi
    rm $data/tmp.reco
  else  # else wav indexed by utterance -> filter on this.
    [ -f $data/wav.scp ] &&
      utils/filter_scp.pl $dsn/utt2spk $data/wav.scp > $dsn/wav.scp
  fi
  fi # else it would have been handled above, see maybe_wav.
done

rm -f $data/.split_lock
@@ -669,13 +669,13 @@ HTML_HEADER = doc/header.html
# 180 is cyan, 240 is blue, 300 purple, and 360 is red again.
# The allowed range is 0 to 359.

HTML_COLORSTYLE_HUE = 26
HTML_COLORSTYLE_HUE = 31

# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of
# the colors in the HTML output. For a value of 0 the output will use
# grayscales only. A value of 255 will produce the most vivid colors.

HTML_COLORSTYLE_SAT = 80
HTML_COLORSTYLE_SAT = 115

# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to
# the luminance component of the colors in the HTML output. Values below

@@ -684,7 +684,7 @@ HTML_COLORSTYLE_SAT = 80
# so 80 represents a gamma of 0.8, The value 220 represents a gamma of 2.2,
# and 100 does not change the gamma.

HTML_COLORSTYLE_GAMMA = 90
HTML_COLORSTYLE_GAMMA = 80
@@ -50,9 +50,15 @@ class LatticeFasterOnlineDecoder {
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;

  struct BestPathIterator {
    void *tok;
    int32 frame;
    // note, "frame" is the frame-index of the frame you'll get the
    // transition-id for next time, if you call TraceBackBestPath on this
    // iterator (assuming it's not an epsilon transition).  Note that this
    // is one less than you might reasonably expect, e.g. it's -1 for
    // the nonemitting transitions before the first frame.
    BestPathIterator(void *t, int32 f): tok(t), frame(f) { }
    bool Done() { return tok == NULL; }
  };
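The "one less than you might reasonably expect" convention above is easy to get wrong, so here is a minimal, self-contained sketch of the same iterator pattern. The Token type and TraceBack function below are hypothetical simplifications for illustration, not Kaldi's real decoder internals:

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical, simplified token: each decoded frame owns one best token
// linked to its predecessor; ilabel is the transition-id taken into this frame.
struct Token { const Token *prev; int ilabel; };

struct BestPathIterator {
  const Token *tok;
  int frame;  // frame-index of the frame whose transition-id the next
              // TraceBack() call returns; -1 before the first frame.
  bool Done() const { return tok == nullptr; }
};

// Steps one arc back along the best path, mirroring TraceBackBestPath():
// it outputs the ilabel of the current token and returns an iterator whose
// "frame" is one less than the frame just consumed.
BestPathIterator TraceBack(BestPathIterator iter, int *ilabel_out) {
  *ilabel_out = iter.tok->ilabel;
  return BestPathIterator{iter.tok->prev, iter.frame - 1};
}

int main() {
  // Build a 3-frame best path: tokens[0] <- tokens[1] <- tokens[2].
  std::vector<Token> tokens(3);
  tokens[0] = Token{nullptr, 101};
  tokens[1] = Token{&tokens[0], 102};
  tokens[2] = Token{&tokens[1], 103};
  // Analogue of BestPathEnd(): start at the last decoded frame.
  BestPathIterator iter{&tokens[2], 2};
  for (int frame = 2; frame >= 0; frame--) {
    int ilabel;
    iter = TraceBack(iter, &ilabel);
    // The "one less than you might expect" convention from the header:
    assert(iter.frame == frame - 1);
    std::printf("frame %d: transition-id %d\n", frame, ilabel);
  }
  assert(iter.Done());  // tok == nullptr, frame == -1
  return 0;
}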
@@ -56,4 +56,6 @@ fi
# moved the header.html to doc/ and edited it to include the following snippet,
# and added it to the repo.
#<link rel="icon" type="image/png" href="http://kaldi.sf.net/favicon.ico">
# Also did similar with stylesheet.
@@ -49,9 +49,9 @@ namespace kaldi {
  be run in order to build the systems used for alignment.

  Regarding which of the two setups you should use:
   - Karel's setup (nnet1) supports training on a single GPU card, which allows
   - Karel's setup (\ref dnn1 "nnet1") supports training on a single GPU card, which allows
     the implementation to be simpler and relatively easy to modify.
   - Dan's setup (nnet2) is more flexible in how
   - Dan's setup (\ref dnn2 "nnet2") is more flexible in how
     you can train: it supports using multiple GPUs, or multiple CPU's each with
     multiple threads.  Multiple GPU's is the recommended setup.
     They don't have to all be on the same machine.  Both setups give commensurate results.
@@ -23,7 +23,7 @@ $mathjax
 <tbody>
 <tr style="height: 56px;">
  <!--BEGIN PROJECT_LOGO-->
  <td id="projectlogo"><img alt="Logo" src="$relpath$$projectlogo"/ style="padding: 4px 5px 1px 5px"></td>
  <td id="projectlogo"><img alt="Logo" src="$relpath$$projectlogo"/ style="padding: 3px 5px 1px 5px"></td>
  <!--END PROJECT_LOGO-->
  <!--BEGIN PROJECT_NAME-->
  <td style="padding-left: 0.5em;">
@@ -534,7 +534,9 @@ void OnlineIvectorEstimationStats::AccStats(
  for (size_t idx = 0; idx < gauss_post.size(); idx++) {
    int32 g = gauss_post[idx].first;
    double weight = gauss_post[idx].second;
    KALDI_ASSERT(weight >= 0.0);
    // allow negative weights; it's needed in the online iVector extraction
    // with speech-silence detection based on decoder traceback (we subtract
    // stuff we previously added if the traceback changes).
    if (weight == 0.0)
      continue;
    linear_term_.AddMatVec(weight, extractor.Sigma_inv_M_[g], kTrans,

@@ -543,8 +545,9 @@ void OnlineIvectorEstimationStats::AccStats(
    quadratic_term_vec.AddVec(weight, U_g);
    tot_weight += weight;
  }
  if (max_count_ != 0.0) {
    // see comments in header RE max_count for explanation.
  if (max_count_ > 0.0) {
    // see comments in header RE max_count for explanation.  It relates to
    // prior scaling when the count exceeds max_count_
    double old_num_frames = num_frames_,
        new_num_frames = num_frames_ + tot_weight;
    double old_prior_scale = std::max(old_num_frames, max_count_) / max_count_,
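Why the KALDI_ASSERT(weight >= 0.0) had to go: with decoder-traceback silence weighting, a frame that was already accumulated may be re-weighted later, and the correction arrives as a negative weight. A toy, std-only sketch of the idea (a simplified stats structure, not Kaldi's actual class):

#include <cassert>
#include <cmath>
#include <vector>

// Simplified first-order stats: linear_term[i] += weight * feat[i], plus a
// total count.  Kaldi's AccStats does the analogous updates on matrix terms.
struct Stats {
  std::vector<double> linear_term;
  double tot_weight;
  explicit Stats(size_t dim) : linear_term(dim, 0.0), tot_weight(0.0) {}
  void Acc(const std::vector<double> &feat, double weight) {
    // Negative weights are allowed: they subtract stuff we previously added.
    for (size_t i = 0; i < feat.size(); i++)
      linear_term[i] += weight * feat[i];
    tot_weight += weight;
  }
};

int main() {
  Stats stats(2);
  std::vector<double> feat(2);
  feat[0] = 1.0; feat[1] = 2.0;
  stats.Acc(feat, 1.0);   // frame first counted with full weight
  stats.Acc(feat, -1.0);  // traceback reclassified it as silence with
                          // silence-weight 0.0: delta-weight is 0.0 - 1.0
  assert(std::fabs(stats.linear_term[0]) < 1e-12 &&
         std::fabs(stats.tot_weight) < 1e-12);  // as if never added
  return 0;
}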
@@ -538,7 +538,13 @@ int32 LinearCgd(const LinearCgdOptions &opts,
    p.AddVec(-1.0, r);
    r_cur_norm_sq = r_next_norm_sq;
  }
  if (r_cur_norm_sq > r_initial_norm_sq) {

  // note: the first element of the && is only there to save compute.
  // the residual r is A x - b, and r_cur_norm_sq and r_initial_norm_sq are
  // of the form r * r, so it's clear that b * b has the right dimension to
  // compare with the residual.
  if (r_cur_norm_sq > r_initial_norm_sq &&
      r_cur_norm_sq > r_initial_norm_sq + 1.0e-10 * VecVec(b, b)) {
    KALDI_WARN << "Doing linear CGD in dimension " << A.NumRows() << ", after " << k
               << " iterations the squared residual has got worse, "
               << r_cur_norm_sq << " > " << r_initial_norm_sq
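The patched condition tolerates roundoff-level increases of the squared residual by comparing against r_initial_norm_sq plus a tiny multiple of b·b. A compact, self-contained illustration of the same check around a plain conjugate-gradient solve (illustrative code only, not the Kaldi routine):

#include <cstdio>
#include <vector>

typedef std::vector<double> Vec;
static double Dot(const Vec &a, const Vec &b) {
  double s = 0.0;
  for (size_t i = 0; i < a.size(); i++) s += a[i] * b[i];
  return s;
}

int main() {
  // SPD system A x = b, A = [[4,1],[1,3]], b = [1,2]; exact x = (1/11, 7/11).
  const double A[2][2] = {{4, 1}, {1, 3}};
  Vec b(2), x(2, 0.0);
  b[0] = 1; b[1] = 2;
  Vec r = b, p = r;  // residual r = b - A x (x = 0 initially)
  double r_initial_norm_sq = Dot(r, r), r_cur_norm_sq = r_initial_norm_sq;
  int k = 0;
  for (; k < 2; k++) {  // a 2-dimensional system converges in at most 2 steps
    Vec Ap(2);
    Ap[0] = A[0][0] * p[0] + A[0][1] * p[1];
    Ap[1] = A[1][0] * p[0] + A[1][1] * p[1];
    double alpha = r_cur_norm_sq / Dot(p, Ap);
    for (int i = 0; i < 2; i++) { x[i] += alpha * p[i]; r[i] -= alpha * Ap[i]; }
    double r_next_norm_sq = Dot(r, r);
    double beta = r_next_norm_sq / r_cur_norm_sq;
    for (int i = 0; i < 2; i++) p[i] = r[i] + beta * p[i];
    r_cur_norm_sq = r_next_norm_sq;
  }
  // The patched check: tolerate tiny roundoff-level residual increases;
  // b.b has the right dimensions to compare against r.r.
  if (r_cur_norm_sq > r_initial_norm_sq &&
      r_cur_norm_sq > r_initial_norm_sq + 1.0e-10 * Dot(b, b))
    std::fprintf(stderr, "after %d iterations the squared residual got worse, "
                 "%g > %g\n", k, r_cur_norm_sq, r_initial_norm_sq);
  std::printf("x = (%g, %g)\n", x[0], x[1]);  // expect about (0.0909, 0.6364)
  return 0;
}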
@@ -895,6 +895,9 @@ class AffineComponentPreconditioned: public AffineComponent {
};


/// Keywords: natural gradient descent, NG-SGD, naturalgradient.  For
/// the top-level of the natural gradient code look here, and also in
/// nnet-precondition-online.h.
/// AffineComponentPreconditionedOnline is, like AffineComponentPreconditioned,
/// a version of AffineComponent that has a non-(multiple of unit) learning-rate
/// matrix.  See nnet-precondition-online.h for a description of the technique.
@@ -32,6 +32,8 @@ namespace nnet2 {


/**
   Keywords for search: natural gradient, naturalgradient, NG-SGD

   It will help to first try to understand ./nnet-precondition.h before reading
   this comment and trying to understand what's going on here.  The motivation
   for this method was that the code in nnet-precondition.h was too slow when
@@ -33,6 +33,7 @@ void OnlineIvectorExtractionInfo::Init(
  min_post = config.min_post;
  posterior_scale = config.posterior_scale;
  max_count = config.max_count;
  num_cg_iters = config.num_cg_iters;
  use_most_recent_ivector = config.use_most_recent_ivector;
  greedy_ivector_extractor = config.greedy_ivector_extractor;
  if (greedy_ivector_extractor && !use_most_recent_ivector) {
@@ -151,31 +152,54 @@ int32 OnlineIvectorFeature::NumFramesReady() const {
  return lda_->NumFramesReady();
}

void OnlineIvectorFeature::UpdateStatsUntilFrame(int32 frame) {
  KALDI_ASSERT(frame >= 0 && frame < this->NumFramesReady());
void OnlineIvectorFeature::UpdateFrameWeights(
    const std::vector<std::pair<int32, BaseFloat> > &delta_weights) {
  // add the elements to delta_weights_, which is a priority queue.  The top
  // element of the priority queue is the lowest numbered frame (we ensured this
  // by making the comparison object std::greater instead of std::less).  Adding
  // elements from top (lower-numbered frames) to bottom (higher-numbered
  // frames) should be most efficient, assuming it's a heap internally.  So we
  // go forward not backward in delta_weights while adding.
  int32 num_frames_ready = NumFramesReady();
  for (size_t i = 0; i < delta_weights.size(); i++) {
    delta_weights_.push(delta_weights[i]);
    int32 frame = delta_weights[i].first;
    KALDI_ASSERT(frame >= 0 && frame < num_frames_ready);
    if (frame > most_recent_frame_with_weight_)
      most_recent_frame_with_weight_ = frame;
  }
  delta_weights_provided_ = true;
}

  int32 feat_dim = lda_normalized_->Dim(),
      ivector_period = info_.ivector_period;

  int32 num_cg_iters = 15; // I don't believe this is very important, so it's
                           // not configurable from the command line for now.

void OnlineIvectorFeature::UpdateStatsForFrame(int32 t,
                                               BaseFloat weight) {
  int32 feat_dim = lda_normalized_->Dim();
  Vector<BaseFloat> feat(feat_dim),  // features given to iVector extractor
      log_likes(info_.diag_ubm.NumGauss());
  lda_normalized_->GetFrame(t, &feat);
  info_.diag_ubm.LogLikelihoods(feat, &log_likes);
  // "posterior" stores the pruned posteriors for Gaussians in the UBM.
  std::vector<std::pair<int32, BaseFloat> > posterior;
  tot_ubm_loglike_ += weight *
      VectorToPosteriorEntry(log_likes, info_.num_gselect,
                             info_.min_post, &posterior);
  for (size_t i = 0; i < posterior.size(); i++)
    posterior[i].second *= info_.posterior_scale * weight;
  lda_->GetFrame(t, &feat);  // get feature without CMN.
  ivector_stats_.AccStats(info_.extractor, feat, posterior);
}

void OnlineIvectorFeature::UpdateStatsUntilFrame(int32 frame) {
  KALDI_ASSERT(frame >= 0 && frame < this->NumFramesReady() &&
               !delta_weights_provided_);
  updated_with_no_delta_weights_ = true;

  int32 ivector_period = info_.ivector_period;
  int32 num_cg_iters = info_.num_cg_iters;

  for (; num_frames_stats_ <= frame; num_frames_stats_++) {
    int32 t = num_frames_stats_;  // Frame whose stats we want to get.
    lda_normalized_->GetFrame(t, &feat);
    info_.diag_ubm.LogLikelihoods(feat, &log_likes);
    // "posterior" stores the pruned posteriors for Gaussians in the UBM.
    std::vector<std::pair<int32, BaseFloat> > posterior;
    tot_ubm_loglike_ += VectorToPosteriorEntry(log_likes, info_.num_gselect,
                                               info_.min_post, &posterior);
    for (size_t i = 0; i < posterior.size(); i++)
      posterior[i].second *= info_.posterior_scale;
    lda_->GetFrame(t, &feat);  // get feature without CMN.
    ivector_stats_.AccStats(info_.extractor, feat, posterior);

    int32 t = num_frames_stats_;
    UpdateStatsForFrame(t, 1.0);
    if ((!info_.use_most_recent_ivector && t % ivector_period == 0) ||
        (info_.use_most_recent_ivector && t == frame)) {
      ivector_stats_.GetIvector(num_cg_iters, &current_ivector_);
@@ -188,10 +212,57 @@ void OnlineIvectorFeature::UpdateStatsUntilFrame(int32 frame) {
  }
}

void OnlineIvectorFeature::UpdateStatsUntilFrameWeighted(int32 frame) {
  KALDI_ASSERT(frame >= 0 && frame < this->NumFramesReady() &&
               delta_weights_provided_ &&
               ! updated_with_no_delta_weights_ &&
               frame <= most_recent_frame_with_weight_);
  bool debug_weights = true;

  int32 ivector_period = info_.ivector_period;
  int32 num_cg_iters = info_.num_cg_iters;

  for (; num_frames_stats_ <= frame; num_frames_stats_++) {
    int32 t = num_frames_stats_;
    // Instead of just updating frame t, we update all frames that need updating
    // with index <= t, in case old frames were reclassified as silence/nonsilence.
    while (!delta_weights_.empty() &&
           delta_weights_.top().first <= t) {
      std::pair<int32, BaseFloat> p = delta_weights_.top();
      delta_weights_.pop();
      int32 frame = p.first;
      BaseFloat weight = p.second;
      UpdateStatsForFrame(frame, weight);
      if (debug_weights) {
        if (current_frame_weight_debug_.size() <= frame)
          current_frame_weight_debug_.resize(frame + 1, 0.0);
        current_frame_weight_debug_[frame] += weight;
        KALDI_ASSERT(current_frame_weight_debug_[frame] >= -0.01 &&
                     current_frame_weight_debug_[frame] <= 1.01);
      }
    }
    if ((!info_.use_most_recent_ivector && t % ivector_period == 0) ||
        (info_.use_most_recent_ivector && t == frame)) {
      ivector_stats_.GetIvector(num_cg_iters, &current_ivector_);
      if (!info_.use_most_recent_ivector) {  // need to cache iVectors.
        int32 ivec_index = t / ivector_period;
        KALDI_ASSERT(ivec_index == static_cast<int32>(ivectors_history_.size()));
        ivectors_history_.push_back(new Vector<BaseFloat>(current_ivector_));
      }
    }
  }
}


void OnlineIvectorFeature::GetFrame(int32 frame,
                                    VectorBase<BaseFloat> *feat) {
  UpdateStatsUntilFrame(info_.greedy_ivector_extractor ?
                        lda_->NumFramesReady() - 1 : frame);
  int32 frame_to_update_until = (info_.greedy_ivector_extractor ?
                                 lda_->NumFramesReady() - 1 : frame);
  if (!delta_weights_provided_)  // No silence weighting.
    UpdateStatsUntilFrame(frame_to_update_until);
  else
    UpdateStatsUntilFrameWeighted(frame_to_update_until);

  KALDI_ASSERT(feat->Dim() == this->Dim());

  if (info_.use_most_recent_ivector) {
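The delta_weights_ priority queue used above relies on std::greater turning std::priority_queue into a min-heap over (frame, weight) pairs, so top() yields the lowest-numbered pending frame. A self-contained sketch of the push-then-drain-up-to-frame-t pattern (simplified types, no Kaldi dependencies):

#include <cstdio>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

int main() {
  typedef std::pair<int, float> FrameWeight;  // (frame-index, delta-weight)
  // std::greater makes top() the lowest-numbered frame (a min-heap);
  // the default std::less would give the highest-numbered one instead.
  std::priority_queue<FrameWeight, std::vector<FrameWeight>,
                      std::greater<FrameWeight> > delta_weights;

  // Analogue of UpdateFrameWeights(): push in ascending frame order, which
  // is cheapest for a binary heap that already holds smaller keys on top.
  delta_weights.push(FrameWeight(0, 1.0f));
  delta_weights.push(FrameWeight(1, 1.0f));
  delta_weights.push(FrameWeight(1, -1.0f));  // frame 1 reclassified: net 0
  delta_weights.push(FrameWeight(5, 1.0f));

  // Analogue of UpdateStatsUntilFrameWeighted(t): consume every pending
  // weight change for frames <= t, then stop.
  int t = 3;
  while (!delta_weights.empty() && delta_weights.top().first <= t) {
    FrameWeight p = delta_weights.top();
    delta_weights.pop();
    std::printf("apply weight %+.1f to frame %d\n", p.second, p.first);
  }
  std::printf("%zu entries left (frames > %d)\n", delta_weights.size(), t);
  return 0;
}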
@@ -218,8 +289,8 @@ void OnlineIvectorFeature::PrintDiagnostics() const {
    KALDI_VLOG(3) << "Processed no data.";
  } else {
    KALDI_VLOG(3) << "UBM log-likelihood was "
                  << (tot_ubm_loglike_ / num_frames_stats_)
                  << " per frame, over " << num_frames_stats_
                  << (tot_ubm_loglike_ / NumFrames())
                  << " per frame, over " << NumFrames()
                  << " frames.";

    Vector<BaseFloat> temp_ivector(current_ivector_);
@@ -266,7 +337,9 @@ OnlineIvectorFeature::OnlineIvectorFeature(
    ivector_stats_(info_.extractor.IvectorDim(),
                   info_.extractor.PriorOffset(),
                   info_.max_count),
    num_frames_stats_(0), tot_ubm_loglike_(0.0) {
    num_frames_stats_(0), delta_weights_provided_(false),
    updated_with_no_delta_weights_(false),
    most_recent_frame_with_weight_(-1), tot_ubm_loglike_(0.0) {
  info.Check();
  KALDI_ASSERT(base_feature != NULL);
  splice_ = new OnlineSpliceFrames(info_.splice_opts, base_);
@@ -296,8 +369,8 @@ void OnlineIvectorFeature::SetAdaptationState(
}

BaseFloat OnlineIvectorFeature::UbmLogLikePerFrame() const {
  if (num_frames_stats_ == 0) return 0;
  else return tot_ubm_loglike_ / num_frames_stats_;
  if (NumFrames() == 0) return 0;
  else return tot_ubm_loglike_ / NumFrames();
}

BaseFloat OnlineIvectorFeature::ObjfImprPerFrame() const {
@@ -305,4 +378,206 @@ BaseFloat OnlineIvectorFeature::ObjfImprPerFrame() const {
}


OnlineSilenceWeighting::OnlineSilenceWeighting(
    const TransitionModel &trans_model,
    const OnlineSilenceWeightingConfig &config):
    trans_model_(trans_model), config_(config),
    num_frames_output_and_correct_(0) {
  vector<int32> silence_phones;
  SplitStringToIntegers(config.silence_phones_str, ":,", false,
                        &silence_phones);
  for (size_t i = 0; i < silence_phones.size(); i++)
    silence_phones_.insert(silence_phones[i]);
}


void OnlineSilenceWeighting::ComputeCurrentTraceback(
    const LatticeFasterOnlineDecoder &decoder) {
  int32 num_frames_decoded = decoder.NumFramesDecoded(),
      num_frames_prev = frame_info_.size();
  // note, num_frames_prev is not the number of frames previously decoded,
  // it's the generally-larger number of frames that we were requested to
  // provide weights for.
  if (num_frames_prev < num_frames_decoded)
    frame_info_.resize(num_frames_decoded);
  if (num_frames_prev > num_frames_decoded &&
      frame_info_[num_frames_decoded].transition_id != -1)
    KALDI_ERR << "Number of frames decoded decreased";  // Likely bug

  if (num_frames_decoded == 0)
    return;
  int32 frame = num_frames_decoded - 1;
  bool use_final_probs = false;
  LatticeFasterOnlineDecoder::BestPathIterator iter =
      decoder.BestPathEnd(use_final_probs, NULL);
  while (frame >= 0) {
    LatticeArc arc;
    arc.ilabel = 0;
    while (arc.ilabel == 0)  // the while loop skips over input-epsilons
      iter = decoder.TraceBackBestPath(iter, &arc);
    // note, the iter.frame values are slightly unintuitively defined,
    // they are one less than you might expect.
    KALDI_ASSERT(iter.frame == frame - 1);

    if (frame_info_[frame].token == iter.tok) {
      // we know that the traceback from this point back will be identical, so
      // no point tracing back further.  Note: we are comparing memory addresses
      // of tokens of the decoder; this guarantees it's the same exact token
      // because tokens, once allocated on a frame, are only deleted, never
      // reallocated for that frame.
      break;
    }

    if (num_frames_output_and_correct_ > frame)
      num_frames_output_and_correct_ = frame;

    frame_info_[frame].token = iter.tok;
    frame_info_[frame].transition_id = arc.ilabel;
    frame--;
    // leave frame_info_.current_weight at zero for now (as set in the
    // constructor), reflecting that we haven't already output a weight for that
    // frame.
  }
}

int32 OnlineSilenceWeighting::GetBeginFrame() {
  int32 max_duration = config_.max_state_duration;
  if (max_duration <= 0 || num_frames_output_and_correct_ == 0)
    return num_frames_output_and_correct_;

  // t_last_untouched is the index of the last frame that is not newly touched
  // by ComputeCurrentTraceback.  We are interested in whether it is part of a
  // run of length greater than max_duration, since this would force it
  // to be treated as silence (note: typically a non-silence phone that's very
  // long is really silence, for example this can happen with the word "mm").

  int32 t_last_untouched = num_frames_output_and_correct_ - 1,
      t_end = frame_info_.size();
  int32 transition_id = frame_info_[t_last_untouched].transition_id;
  // no point searching longer than max_duration; when the length of the run is
  // at least that much, a longer length makes no difference.
  int32 lower_search_bound = std::max(0, t_last_untouched - max_duration),
      upper_search_bound = std::min(t_last_untouched + max_duration, t_end - 1),
      t_lower, t_upper;

  // t_lower will be the first index in the run of equal transition-ids.
  for (t_lower = t_last_untouched;
       t_lower > lower_search_bound &&
           frame_info_[t_lower - 1].transition_id == transition_id; t_lower--);

  // t_upper will be the last index in the run of equal transition-ids.
  for (t_upper = t_last_untouched;
       t_upper < upper_search_bound &&
           frame_info_[t_upper + 1].transition_id == transition_id; t_upper++);

  int32 run_length = t_upper - t_lower + 1;
  if (run_length <= max_duration) {
    // we wouldn't treat this run as being silence, as it's within
    // the duration limit.  So we return the default value
    // num_frames_output_and_correct_ as our lower bound for processing.
    return num_frames_output_and_correct_;
  }
  int32 old_run_length = t_last_untouched - t_lower + 1;
  if (old_run_length > max_duration) {
    // The run-length before we got this new data was already longer than the
    // max-duration, so would already have been treated as silence.  Therefore
    // we don't have to encompass it all; we just include a long enough length
    // in the region we are going to process, that the run-length in that region
    // is longer than max_duration.
    int32 ans = t_upper - max_duration;
    KALDI_ASSERT(ans >= t_lower);
    return ans;
  } else {
    return t_lower;
  }
}

void OnlineSilenceWeighting::GetDeltaWeights(
    int32 num_frames_ready,
    std::vector<std::pair<int32, BaseFloat> > *delta_weights) {
  const int32 max_state_duration = config_.max_state_duration;
  const BaseFloat silence_weight = config_.silence_weight;

  delta_weights->clear();

  if (frame_info_.size() < static_cast<size_t>(num_frames_ready))
    frame_info_.resize(num_frames_ready);

  // we may have to make begin_frame earlier than num_frames_output_and_correct_
  // so that max_state_duration is properly enforced.  GetBeginFrame() handles
  // this logic.
  int32 begin_frame = GetBeginFrame(),
      frames_out = static_cast<int32>(frame_info_.size()) - begin_frame;
  // frames_out is the number of frames we will output.
  KALDI_ASSERT(frames_out >= 0);
  vector<BaseFloat> frame_weight(frames_out, 1.0);
  // we will set frame_weight to the value silence_weight for silence frames and for
  // transition-ids that repeat with duration > max_state_duration.  Frames newer
  // than the most recent traceback will get a weight equal to the weight for the
  // most recent frame in the traceback; or the silence weight, if there is no
  // traceback at all available yet.

  // First treat some special cases.
  if (frames_out == 0)  // Nothing to output.
    return;
  if (frame_info_[begin_frame].transition_id == -1) {
    // We do not have any traceback at all within the frames we are to output...
    // find the most recent weight that we output and apply the same weight to
    // all the new output; or output the silence weight, if nothing was output.
    BaseFloat weight = (begin_frame == 0 ? silence_weight :
                        frame_info_[begin_frame - 1].current_weight);
    for (int32 offset = 0; offset < frames_out; offset++)
      frame_weight[offset] = weight;
  } else {
    int32 current_run_start_offset = 0;
    for (int32 offset = 0; offset < frames_out; offset++) {
      int32 frame = begin_frame + offset;
      int32 transition_id = frame_info_[frame].transition_id;
      if (transition_id == -1) {
        // this frame does not yet have a decoder traceback, so just
        // duplicate the silence/non-silence status of the most recent
        // frame we have a traceback for (probably a reasonable guess).
        frame_weight[offset] = frame_weight[offset - 1];
      } else {
        int32 phone = trans_model_.TransitionIdToPhone(transition_id);
        bool is_silence = (silence_phones_.count(phone) != 0);
        if (is_silence)
          frame_weight[offset] = silence_weight;
        // now deal with max-duration issues.
        if (max_state_duration > 0 &&
            (offset + 1 == frames_out ||
             transition_id != frame_info_[frame + 1].transition_id)) {
          // If this is the last frame of a run...
          int32 run_length = offset - current_run_start_offset + 1;
          if (run_length >= max_state_duration) {
            // treat runs of the same transition-id longer than the max, as
            // silence, even if they were not silence.
            for (int32 offset2 = current_run_start_offset;
                 offset2 <= offset; offset2++)
              frame_weight[offset2] = silence_weight;
          }
          if (offset + 1 < frames_out)
            current_run_start_offset = offset + 1;
        }
      }
    }
  }
  // Now commit the stats...
  for (int32 offset = 0; offset < frames_out; offset++) {
    int32 frame = begin_frame + offset;
    BaseFloat old_weight = frame_info_[frame].current_weight,
        new_weight = frame_weight[offset],
        weight_diff = new_weight - old_weight;
    frame_info_[frame].current_weight = new_weight;
    KALDI_VLOG(6) << "Weight for frame " << frame << " changing from "
                  << old_weight << " to " << new_weight;
    // Even if the delta-weight is zero for the last frame, we provide it,
    // because the identity of the most recent frame with a weight is used in
    // some debugging/checking code.
    if (weight_diff != 0.0 || offset + 1 == frames_out)
      delta_weights->push_back(std::make_pair(frame, weight_diff));
  }
}

}  // namespace kaldi
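As a compact model of what GetDeltaWeights() computes above: given per-frame phone classifications from the traceback, a silence-phone set, and max_state_duration, assign each frame either 1.0 or silence_weight and emit only the changes as (frame-index, delta-weight) pairs. The sketch below is a simplification (phones come from a plain vector rather than a decoder, and the emit-last-frame special case is omitted):

#include <cstdio>
#include <set>
#include <vector>

int main() {
  const float silence_weight = 0.0f;  // cf. --silence-weight
  const int max_state_duration = 3;   // cf. --max-state-duration
  std::set<int> silence_phones;       // cf. --silence-phones
  silence_phones.insert(1);

  // Hypothetical per-frame phone ids from a best-path traceback; the run of
  // four 7's reaches max_state_duration, so it gets downweighted like silence.
  int tmp[] = {1, 1, 5, 7, 7, 7, 7, 6};
  std::vector<int> phones(tmp, tmp + 8);
  // Weights we previously told the extractor to use (FrameInfo::current_weight
  // starts at 0.0 in the real code).
  std::vector<float> prev_weight(phones.size(), 0.0f);

  std::vector<float> weight(phones.size(), 1.0f);
  size_t run_start = 0;
  for (size_t i = 0; i < phones.size(); i++) {
    if (silence_phones.count(phones[i]))
      weight[i] = silence_weight;
    bool run_ends = (i + 1 == phones.size() || phones[i + 1] != phones[i]);
    if (run_ends) {
      if (i - run_start + 1 >= static_cast<size_t>(max_state_duration)) {
        // an overly long run of one id: treat it as silence, even if it wasn't
        for (size_t j = run_start; j <= i; j++)
          weight[j] = silence_weight;
      }
      run_start = i + 1;
    }
  }
  // Emit only the changes, as (frame-index, delta-weight).
  for (size_t i = 0; i < phones.size(); i++)
    if (weight[i] != prev_weight[i])
      std::printf("frame %zu: delta %+.1f\n", i, weight[i] - prev_weight[i]);
  return 0;
}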
@@ -32,6 +32,7 @@
#include "gmm/diag-gmm.h"
#include "feat/online-feature.h"
#include "ivector/ivector-extractor.h"
#include "decoder/lattice-faster-online-decoder.h"

namespace kaldi {
/// @addtogroup onlinefeat OnlineFeatureExtraction
@@ -75,6 +76,9 @@ struct OnlineIvectorExtractionConfig {
                        // atypical-looking iVectors for very long utterances.
                        // Interpret this as a number of frames times
                        // posterior_scale, typically 1/10 of a frame count.

  int32 num_cg_iters;  // set to 15.  I don't believe this is very important, so it's
                       // not configurable from the command line for now.


  // If use_most_recent_ivector is true, we always return the most recent
@@ -94,10 +98,10 @@ struct OnlineIvectorExtractionConfig {
  // (assuming you provided info from a previous utterance of the same speaker
  // by calling SetAdaptationState()).
  BaseFloat max_remembered_frames;

  OnlineIvectorExtractionConfig(): ivector_period(10), num_gselect(5),
                                   min_post(0.025), posterior_scale(0.1),
                                   max_count(0.0),
                                   max_count(0.0), num_cg_iters(15),
                                   use_most_recent_ivector(true),
                                   greedy_ivector_extractor(false),
                                   max_remembered_frames(1000) { }
@@ -112,7 +116,8 @@ struct OnlineIvectorExtractionConfig {
                 "iVector extraction");
    po->Register("cmvn-config", &cmvn_config_rxfilename, "Configuration "
                 "file for online CMVN features (e.g. conf/online_cmvn.conf),"
                 "only used for iVector extraction");
                 "only used for iVector extraction.  Contains options "
                 "as for the program 'apply-cmvn-online'");
    po->Register("splice-config", &splice_config_rxfilename, "Configuration file "
                 "for frame splicing (--left-context and --right-context "
                 "options); used for iVector extraction.");
@@ -144,7 +149,8 @@ struct OnlineIvectorExtractionConfig {
                 "number of frames of adaptation history that we carry through "
                 "to later utterances of the same speaker (having a finite "
                 "number allows the speaker adaptation state to change over "
                 "time");
                 "time).  Interpret as a real frame count, i.e. not a count "
                 "scaled by --posterior-scale.");
  }
};
@@ -169,6 +175,7 @@ struct OnlineIvectorExtractionInfo {
  BaseFloat min_post;
  BaseFloat posterior_scale;
  BaseFloat max_count;
  int32 num_cg_iters;
  bool use_most_recent_ivector;
  bool greedy_ivector_extractor;
  BaseFloat max_remembered_frames;
@@ -244,6 +251,14 @@ class OnlineIvectorFeature: public OnlineFeatureInterface {
  /// delete it while this class or others copied from it still exist.
  explicit OnlineIvectorFeature(const OnlineIvectorExtractionInfo &info,
                                OnlineFeatureInterface *base_feature);

  // This version of the constructor accepts per-frame weights (relates to
  // downweighting silence).  This is intended for use in offline operation,
  // i.e. during training.  [will implement this when needed.]
  //explicit OnlineIvectorFeature(const OnlineIvectorExtractionInfo &info,
  //                              std::vector<BaseFloat> frame_weights,
  //                              OnlineFeatureInterface *base_feature);


  // Member functions from OnlineFeatureInterface:
@@ -253,7 +268,6 @@ class OnlineIvectorFeature: public OnlineFeatureInterface {
  virtual int32 NumFramesReady() const;
  virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat);


  /// Set the adaptation state to a particular value, e.g. reflecting previous
  /// utterances of the same speaker; this will generally be called after
  /// constructing a new instance of this class.
@@ -275,9 +289,39 @@ class OnlineIvectorFeature: public OnlineFeatureInterface {
  // Objective improvement per frame from iVector estimation, versus default iVector
  // value, measured at utterance end.
  BaseFloat ObjfImprPerFrame() const;

  // returns number of frames seen (but not counting the posterior-scale).
  BaseFloat NumFrames() const {
    return ivector_stats_.NumFrames() / info_.posterior_scale;
  }


  // If you are downweighting silence, you can call
  // OnlineSilenceWeighting::GetDeltaWeights and supply the output to this class
  // using UpdateFrameWeights().  The reason why this call happens outside this
  // class, rather than this class pulling in the data weights, relates to
  // multi-threaded operation, and also to avoiding giving this class
  // excessive dependencies.
  //
  // You must either always call this as soon as new data becomes available
  // (ideally just after calling AcceptWaveform), or never call it for the
  // lifetime of this object.
  void UpdateFrameWeights(
      const std::vector<std::pair<int32, BaseFloat> > &delta_weights);

 private:
  virtual void UpdateStatsUntilFrame(int32 frame);
  // this function adds "weight" to the stats for frame "frame".
  void UpdateStatsForFrame(int32 frame,
                           BaseFloat weight);

  // This is the original UpdateStatsUntilFrame that is called when there is
  // no data-weighting involved.
  void UpdateStatsUntilFrame(int32 frame);

  // This is the new UpdateStatsUntilFrame that is called when there is
  // data-weighting (i.e. when the user has been calling UpdateFrameWeights()).
  void UpdateStatsUntilFrameWeighted(int32 frame);

  void PrintDiagnostics() const;

  const OnlineIvectorExtractionInfo &info_;
@@ -295,9 +339,37 @@ class OnlineIvectorFeature: public OnlineFeatureInterface {
  OnlineIvectorEstimationStats ivector_stats_;

  /// num_frames_stats_ is the number of frames of data we have already
  /// accumulated from this utterance and put in ivector_stats_.
  /// accumulated from this utterance and put in ivector_stats_.  Each frame t <
  /// num_frames_stats_ is in the stats.  In case you are doing the
  /// silence-weighted iVector estimation, with UpdateFrameWeights() being
  /// called, this variable is still used but you may later have to revisit
  /// earlier frames to adjust their weights... see the code.
  int32 num_frames_stats_;

  /// delta_weights_ is written to by UpdateFrameWeights,
  /// in the case where the iVector estimation is silence-weighted using the decoder
  /// traceback.  Its elements are consumed by UpdateStatsUntilFrameWeighted().
  /// We provide std::greater<std::pair<int32, BaseFloat> > as the comparison type
  /// (default is std::less) so that the lowest-numbered frame, not the highest-numbered
  /// one, will be returned by top().
  std::priority_queue<std::pair<int32, BaseFloat>,
                      std::vector<std::pair<int32, BaseFloat> >,
                      std::greater<std::pair<int32, BaseFloat> > > delta_weights_;

  /// this is only used for validating that the frame-weighting code is not buggy.
  std::vector<BaseFloat> current_frame_weight_debug_;

  /// delta_weights_provided_ is set to true if UpdateFrameWeights was ever called; it's
  /// used to detect wrong usage of this class.
  bool delta_weights_provided_;
  /// The following is also used to detect wrong usage of this class; it's set
  /// to true if UpdateStatsUntilFrame() was ever called.
  bool updated_with_no_delta_weights_;

  /// if UpdateFrameWeights() was ever called, this keeps track of the most recent
  /// frame that ever had a weight.  It's mostly for detecting errors.
  int32 most_recent_frame_with_weight_;

  /// The following is only needed for diagnostics.
  double tot_ubm_loglike_;
@@ -312,10 +384,131 @@ class OnlineIvectorFeature: public OnlineFeatureInterface {
  /// ivectors_history_[i] contains the iVector we estimated on
  /// frame t = i * info_.ivector_period.
  std::vector<Vector<BaseFloat>* > ivectors_history_;


};


struct OnlineSilenceWeightingConfig {
  std::string silence_phones_str;
  // The weighting factor that we apply to silence phones in the iVector
  // extraction.  This option is only relevant if the --silence-phones option is
  // set.
  BaseFloat silence_weight;

  // Transition-ids that get repeated at least this many times (if
  // max_state_duration > 0) are treated as silence.
  BaseFloat max_state_duration;

  // This is the scale that we apply to data that we don't yet have a decoder
  // traceback for, in the online silence
  BaseFloat new_data_weight;

  bool Active() const {
    return !silence_phones_str.empty() && silence_weight != 1.0;
  }

  OnlineSilenceWeightingConfig():
      silence_weight(1.0), max_state_duration(-1) { }

  void Register(OptionsItf *po) {
    po->Register("silence-phones", &silence_phones_str, "(RE weighting in "
                 "iVector estimation for online decoding) List of integer ids of "
                 "silence phones, separated by colons (or commas).  Data that "
                 "(according to the traceback of the decoder) corresponds to "
                 "these phones will be downweighted by --silence-weight.");
    po->Register("silence-weight", &silence_weight, "(RE weighting in "
                 "iVector estimation for online decoding) Weighting factor for frames "
                 "that the decoder trace-back identifies as silence; only "
                 "relevant if the --silence-phones option is set.");
    po->Register("max-state-duration", &max_state_duration, "(RE weighting in "
                 "iVector estimation for online decoding) Maximum allowed "
                 "duration of a single transition-id; runs with durations longer "
                 "than this will be weighted down to the silence-weight.");
  }
  // e.g. prefix = "ivector-silence-weighting"
  void RegisterWithPrefix(std::string prefix, OptionsItf *po) {
    ParseOptions po_prefix(prefix, po);
    this->Register(&po_prefix);
  }
};

// This class is responsible for keeping track of the best-path traceback from
// the decoder (efficiently) and computing a weighting of the data based on the
// classification of frames as silence (or not silence)... also with a duration
// limitation, so data from a very long run of the same transition-id will get
// weighted down.  (this is often associated with misrecognition or silence).
class OnlineSilenceWeighting {
 public:
  // Note: you would initialize a new copy of this object for each new
  // utterance.
  OnlineSilenceWeighting(const TransitionModel &trans_model,
                         const OnlineSilenceWeightingConfig &config);

  bool Active() const { return config_.Active(); }

  // This should be called before GetDeltaWeights, so this class knows about the
  // traceback info from the decoder.  It records the traceback information from
  // the decoder using its BestPathEnd() and related functions.
  void ComputeCurrentTraceback(const LatticeFasterOnlineDecoder &decoder);

  // Calling this function gets the changes in weight that require us to modify
  // the stats... the output format is (frame-index, delta-weight).  The
  // num_frames_ready argument is the number of frames available at the input
  // (or equivalently, output) of the online iVector extractor class, which may
  // be more than the currently available decoder traceback.  How many frames
  // of weights it outputs depends on how much "num_frames_ready" increased
  // since last time we called this function, and whether the decoder traceback
  // changed.  Negative delta_weights might occur if frames previously
  // classified as non-silence become classified as silence if the decoder's
  // traceback changes.  You must call this function with "num_frames_ready"
  // arguments that only increase, not decrease, with time.  You would provide
  // this output to class OnlineIvectorFeature by calling its function
  // UpdateFrameWeights with the output.
  void GetDeltaWeights(
      int32 num_frames_ready,
      std::vector<std::pair<int32, BaseFloat> > *delta_weights);

 private:
  const TransitionModel &trans_model_;
  const OnlineSilenceWeightingConfig &config_;

  unordered_set<int32> silence_phones_;

  struct FrameInfo {
    // The only reason we need the token pointer is to know how far back we have to
    // trace before the traceback is the same as what we previously traced back.
    void *token;
    int32 transition_id;
    // current_weight is the weight we've previously told the iVector
    // extractor to use for this frame, if any.  It may not equal the
    // weight we "want" it to use (any difference between the two will
    // be output when the user calls GetDeltaWeights()).
    BaseFloat current_weight;
    FrameInfo(): token(NULL), transition_id(-1), current_weight(0.0) {}
  };

  // gets the frame at which we need to begin our processing in
  // GetDeltaWeights... normally this is equal to
  // num_frames_output_and_correct_, but it may be earlier in case
  // max_state_duration is relevant.
  int32 GetBeginFrame();

  std::vector<FrameInfo> frame_info_;

  // This records how many frames have been output and that currently reflect
  // the traceback accurately.  It is used to avoid GetDeltaWeights() having to
  // visit each frame as far back as t = 0, each time it is called.
  // GetDeltaWeights() sets this to the number of frames that it output, and
  // ComputeCurrentTraceback() then reduces it to however far it traced back.
  // However, we may have to go further back in time than this in order to
  // properly honor the "max-state-duration" config.  This, if needed, is done
  // in GetDeltaWeights() before outputting the delta weights.
  int32 num_frames_output_and_correct_;
};


/// @} End of "addtogroup onlinefeat"
}  // namespace kaldi

#endif  // KALDI_ONLINE2_ONLINE_NNET2_FEATURE_PIPELINE_H_
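To make the intended call protocol concrete — ComputeCurrentTraceback() first, then GetDeltaWeights(), then OnlineIvectorFeature::UpdateFrameWeights() — here is a runnable stub sketch of the wiring. The Stub* types below are invented stand-ins so the example compiles on its own (C++11); they are not the real Kaldi API:

#include <cstdio>
#include <utility>
#include <vector>

// Stubs standing in for the Kaldi classes, just to show the call order
// described in the comments above; none of this is the real API surface.
struct StubDecoder { int num_frames_decoded = 0; };

struct StubSilenceWeighting {
  int frames_weighted = 0;
  void ComputeCurrentTraceback(const StubDecoder &d) { (void)d; }  // record traceback (stub)
  void GetDeltaWeights(int num_frames_ready,
                       std::vector<std::pair<int, float> > *out) {
    out->clear();
    for (; frames_weighted < num_frames_ready; frames_weighted++)
      out->push_back(std::make_pair(frames_weighted, 1.0f));  // toy deltas
  }
};

struct StubIvectorFeature {
  int frames_ready = 0;
  int NumFramesReady() const { return frames_ready; }
  void UpdateFrameWeights(const std::vector<std::pair<int, float> > &dw) {
    std::printf("got %zu delta-weights\n", dw.size());
  }
};

int main() {
  StubDecoder decoder;
  StubSilenceWeighting weighting;
  StubIvectorFeature ivector_feature;
  // Per chunk of audio: decode a bit, then refresh the weights, in the order
  // the header above prescribes (traceback first, then deltas, then hand them
  // to the iVector extractor).
  for (int chunk = 0; chunk < 3; chunk++) {
    ivector_feature.frames_ready += 10;  // pretend features arrived
    decoder.num_frames_decoded += 8;     // pretend we decoded some frames
    weighting.ComputeCurrentTraceback(decoder);
    std::vector<std::pair<int, float> > delta_weights;
    weighting.GetDeltaWeights(ivector_feature.NumFramesReady(), &delta_weights);
    ivector_feature.UpdateFrameWeights(delta_weights);
  }
  return 0;
}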
@@ -116,12 +116,13 @@ SingleUtteranceNnet2DecoderThreaded::SingleUtteranceNnet2DecoderThreaded(
    const fst::Fst<fst::StdArc> &fst,
    const OnlineNnet2FeaturePipelineInfo &feature_info,
    const OnlineIvectorExtractorAdaptationState &adaptation_state):
    config_(config), am_nnet_(am_nnet), tmodel_(tmodel), sampling_rate_(0.0),
    num_samples_received_(0), input_finished_(false),
    feature_pipeline_(feature_info), feature_buffer_start_frame_(0),
    feature_buffer_finished_(false), decodable_(tmodel),
    num_frames_decoded_(0), decoder_(fst, config_.decoder_opts),
    abort_(false), error_(false) {
    config_(config), am_nnet_(am_nnet), tmodel_(tmodel), sampling_rate_(0.0),
    num_samples_received_(0), input_finished_(false),
    feature_pipeline_(feature_info),
    silence_weighting_(tmodel, feature_info.silence_weighting_config),
    decodable_(tmodel),
    num_frames_decoded_(0), decoder_(fst, config_.decoder_opts),
    abort_(false), error_(false) {
  // if the user supplies an adaptation state that was not freshly initialized,
  // it means that we take the adaptation state from the previous
  // utterance(s)... this only makes sense if those previous utterance(s) are
@@ -137,28 +138,13 @@ SingleUtteranceNnet2DecoderThreaded::SingleUtteranceNnet2DecoderThreaded(
  // will not be called.  So we don't have to be careful about setting the
  // thread pointers to NULL after we've joined them.
  if ((ret=pthread_create(&(threads_[0]),
                          &pthread_attr, RunFeatureExtraction,
                          &pthread_attr, RunNnetEvaluation,
                          (void*)this)) != 0) {
    const char *c = strerror(ret);
    if (c == NULL) { c = "[NULL]"; }
    KALDI_ERR << "Error creating thread, errno was: " << c;
  }
  if ((ret=pthread_create(&(threads_[1]),
                          &pthread_attr, RunNnetEvaluation,
                          (void*)this)) != 0) {
    const char *c = strerror(ret);
    if (c == NULL) { c = "[NULL]"; }
    bool error = true;
    AbortAllThreads(error);
    KALDI_WARN << "Error creating thread, errno was: " << c
               << " (will rejoin already-created threads).";
    if (pthread_join(threads_[0], NULL)) {
      KALDI_ERR << "Error rejoining thread.";
    } else {
      KALDI_ERR << "Error creating thread, errno was: " << c;
    }
  }
  if ((ret=pthread_create(&(threads_[2]),
                          &pthread_attr, RunDecoderSearch,
                          (void*)this)) != 0) {
    const char *c = strerror(ret);
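The reshuffled error handling above follows the standard pthread_create pattern: the return value (not errno) carries the error code, and on a later failure the already-created threads are rejoined before raising the fatal error, so nothing is leaked. A bare-bones POSIX version of the same pattern (no Kaldi macros; compile with -pthread):

#include <cstdio>
#include <cstring>
#include <pthread.h>

static void *Worker(void *arg) {
  (void)arg;
  return NULL;
}

int main() {
  pthread_t threads[2];
  int ret;
  if ((ret = pthread_create(&threads[0], NULL, Worker, NULL)) != 0) {
    // pthread_create returns the error code directly; it does not set errno.
    std::fprintf(stderr, "Error creating thread: %s\n", std::strerror(ret));
    return 1;
  }
  if ((ret = pthread_create(&threads[1], NULL, Worker, NULL)) != 0) {
    // Second creation failed: rejoin the thread we already created before
    // bailing out, so we don't leave it running at exit.
    std::fprintf(stderr, "Error creating thread: %s (will rejoin "
                 "already-created threads)\n", std::strerror(ret));
    if (pthread_join(threads[0], NULL) != 0)
      std::fprintf(stderr, "Error rejoining thread\n");
    return 1;
  }
  pthread_join(threads[0], NULL);
  pthread_join(threads[1], NULL);
  return 0;
}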
@@ -167,7 +153,7 @@ SingleUtteranceNnet2DecoderThreaded::SingleUtteranceNnet2DecoderThreaded(
    AbortAllThreads(error);
    KALDI_WARN << "Error creating thread, errno was: " << c
               << " (will rejoin already-created threads).";
    if (pthread_join(threads_[0], NULL) || pthread_join(threads_[1], NULL)) {
    if (pthread_join(threads_[0], NULL)) {
      KALDI_ERR << "Error rejoining thread.";
    } else {
      KALDI_ERR << "Error creating thread, errno was: " << c;
@@ -185,8 +171,10 @@ SingleUtteranceNnet2DecoderThreaded::~SingleUtteranceNnet2DecoderThreaded() {
  // join all the threads (this avoids leaving zombie threads around, or threads
  // that might be accessing a destructed object).
  WaitForAllThreads();
  DeletePointers(&input_waveform_);
  DeletePointers(&feature_buffer_);
  while (!input_waveform_.empty()) {
    delete input_waveform_.front();
    input_waveform_.pop_front();
  }
}

void SingleUtteranceNnet2DecoderThreaded::AcceptWaveform(
@@ -313,7 +301,6 @@ void SingleUtteranceNnet2DecoderThreaded::AbortAllThreads(bool error) {
  if (error)
    error_ = true;
  waveform_synchronizer_.SetAbort();
  feature_synchronizer_.SetAbort();
  decodable_synchronizer_.SetAbort();
}
|
|||
return ans;
|
||||
}
|
||||
|
||||
void* SingleUtteranceNnet2DecoderThreaded::RunFeatureExtraction(void *ptr_in) {
|
||||
SingleUtteranceNnet2DecoderThreaded *me =
|
||||
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
|
||||
try {
|
||||
if (!me->RunFeatureExtractionInternal() && !me->abort_)
|
||||
KALDI_ERR << "Returned abnormally and abort was not called";
|
||||
} catch(const std::exception &e) {
|
||||
KALDI_WARN << "Caught exception: " << e.what();
|
||||
// if an error happened in one thread, we need to make sure the other
|
||||
// threads can exit too.
|
||||
bool error = true;
|
||||
me->AbortAllThreads(error);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void* SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluation(void *ptr_in) {
|
||||
SingleUtteranceNnet2DecoderThreaded *me =
|
||||
reinterpret_cast<SingleUtteranceNnet2DecoderThreaded*>(ptr_in);
|
||||
|
@ -373,7 +344,7 @@ void* SingleUtteranceNnet2DecoderThreaded::RunDecoderSearch(void *ptr_in) {
|
|||
|
||||
|
||||
void SingleUtteranceNnet2DecoderThreaded::WaitForAllThreads() {
|
||||
for (int32 i = 0; i < 3; i++) { // there are 3 spawned threads.
|
||||
for (int32 i = 0; i < 2; i++) { // there are 2 spawned threads.
|
||||
pthread_t &thread = threads_[i];
|
||||
if (KALDI_PTHREAD_PTR(thread) != 0) {
|
||||
if (pthread_join(thread, NULL)) {
|
||||
|
@@ -387,94 +358,6 @@ void SingleUtteranceNnet2DecoderThreaded::WaitForAllThreads() {
  }
}

bool SingleUtteranceNnet2DecoderThreaded::RunFeatureExtractionInternal() {
  // Note: if any of the functions Lock, UnlockSuccess, UnlockFailure return
  // false, it is because AbortAllThreads() was called, and we return false
  // immediately.

  // num_frames_output is a local variable that keeps track of how many
  // frames we have output to the feature buffer, for this utterance.
  int32 num_frames_output = 0;

  while (true) {
    // First deal with accepting input.
    if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
      return false;
    if (input_waveform_.empty()) {
      if (input_finished_ &&
          !feature_pipeline_.IsLastFrame(feature_pipeline_.NumFramesReady()-1)) {
        // the main thread called InputFinished() and set input_finished_, and
        // we haven't yet registered that fact.  This is progress so
        // UnlockSuccess().
        feature_pipeline_.InputFinished();
        if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
          return false;
      } else {
        // there was no input to process.  However, we only call UnlockFailure() if we
        // are blocked on the fact that there is no input to process; otherwise we
        // call UnlockSuccess().
        if (num_frames_output == feature_pipeline_.NumFramesReady()) {
          // we need to wait until there is more input.
          if (!waveform_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
            return false;
        } else {  // we can keep looping.
          if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
            return false;
        }
      }
    } else {  // there is more wav data.
      {  // braces clarify scope of locking.
        feature_pipeline_mutex_.Lock();
        for (size_t i = 0; i < input_waveform_.size(); i++)
          if (input_waveform_[i]->Dim() != 0)
            feature_pipeline_.AcceptWaveform(sampling_rate_, *input_waveform_[i]);
        feature_pipeline_mutex_.Unlock();
      }
      DeletePointers(&input_waveform_);
      input_waveform_.clear();
      if (!waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
        return false;
    }

    if (!feature_synchronizer_.Lock(ThreadSynchronizer::kProducer)) return false;

    if (feature_buffer_.size() >= config_.max_buffered_features) {
      // we need to block on the output buffer.
      if (!feature_synchronizer_.UnlockFailure(ThreadSynchronizer::kProducer))
        return false;
    } else {

      {  // braces clarify scope of locking.
        feature_pipeline_mutex_.Lock();
        // There is buffer space available; deal with producing output.
        int32 cur_size = feature_buffer_.size(),
            batch_size = config_.feature_batch_size,
            feat_dim = feature_pipeline_.Dim();

        for (int32 t = feature_buffer_start_frame_ +
                 static_cast<int32>(feature_buffer_.size());
             t < feature_buffer_start_frame_ + config_.max_buffered_features &&
                 t < feature_buffer_start_frame_ + cur_size + batch_size &&
                 t < feature_pipeline_.NumFramesReady(); t++) {
          Vector<BaseFloat> *feats = new Vector<BaseFloat>(feat_dim, kUndefined);
          // Note: most of the actual computation occurs here.
          feature_pipeline_.GetFrame(t, feats);
          feature_buffer_.push_back(feats);
        }
        num_frames_output = feature_buffer_start_frame_ + feature_buffer_.size();
        if (feature_pipeline_.IsLastFrame(num_frames_output - 1)) {
          // e.g. user called InputFinished() and we already saw the last frame.
          feature_buffer_finished_ = true;
        }
        feature_pipeline_mutex_.Unlock();
      }
      if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kProducer))
        return false;

      if (feature_buffer_finished_) return true;
    }
  }
}

void SingleUtteranceNnet2DecoderThreaded::ProcessLoglikes(
    const CuVector<BaseFloat> &log_inv_prior,
@ -489,6 +372,56 @@ void SingleUtteranceNnet2DecoderThreaded::ProcessLoglikes(
|
|||
}
|
||||
}

// called from RunNnetEvaluationInternal().  Returns true in the normal case,
// false on error; if it returns false, then we expect that the calling thread
// will terminate.  This assumes the calling thread has already
// locked feature_pipeline_mutex_.
bool SingleUtteranceNnet2DecoderThreaded::FeatureComputation(
    int32 num_frames_output) {

  int32 num_frames_ready = feature_pipeline_.NumFramesReady(),
      num_frames_usable = num_frames_ready - num_frames_output;
  bool features_done = feature_pipeline_.IsLastFrame(num_frames_ready - 1);
  KALDI_ASSERT(num_frames_usable >= 0);
  if (features_done) {
    return true;  // nothing to do (but not an error).
  } else {
    if (num_frames_usable >= config_.nnet_batch_size)
      return true;  // We don't need more data yet.

    // Now try to get more data, if we can.
    if (!waveform_synchronizer_.Lock(ThreadSynchronizer::kConsumer)) {
      return false;
    }
    // we've got the lock.
    if (input_waveform_.empty()) {  // we got no data
      if (input_finished_ &&
          !feature_pipeline_.IsLastFrame(feature_pipeline_.NumFramesReady()-1)) {
        // the main thread called InputFinished() and set input_finished_, and
        // we haven't yet registered that fact.  This is progress so
        // unlock with UnlockSuccess().
        feature_pipeline_.InputFinished();
        return waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer);
      } else {
        // there is no progress.  Unlock with UnlockFailure() so the next call to
        // waveform_synchronizer_.Lock() will block.
        return waveform_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer);
      }
    } else {  // we got some data.  Only take enough of the waveform to
              // give us a maximum nnet batch size of frames to decode.
      while (num_frames_usable < config_.nnet_batch_size &&
             !input_waveform_.empty()) {
        feature_pipeline_.AcceptWaveform(sampling_rate_, *input_waveform_.front());
        delete input_waveform_.front();
        input_waveform_.pop_front();
        num_frames_ready = feature_pipeline_.NumFramesReady();
        num_frames_usable = num_frames_ready - num_frames_output;
      }
      return waveform_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer);
    }
  }
}

bool SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal() {
  // if any of the Lock/Unlock functions return false, it's because
  // AbortAllThreads() was called.

@@ -497,7 +430,7 @@ bool SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal() {
  // re-computing things we've already computed.
  bool pad_input = true;
  nnet2::NnetOnlineComputer computer(am_nnet_.GetNnet(), pad_input);

  // we declare the following as CuVector just to enable GPU support, but
  // we expect this code to be run on CPU in the normal case.
  CuVector<BaseFloat> log_inv_prior(am_nnet_.Priors());

@@ -506,46 +439,58 @@ bool SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal() {
  log_inv_prior.Scale(-1.0);

  int32 num_frames_output = 0;

  while (true) {
    bool last_time = false;

    if (!feature_synchronizer_.Lock(ThreadSynchronizer::kConsumer))
      return false;

    /****** Begin locking of feature pipeline mutex. ******/
    feature_pipeline_mutex_.Lock();
    if (!FeatureComputation(num_frames_output)) {  // error
      feature_pipeline_mutex_.Unlock();
      return false;
    }
    // take care of silence weighting.
    if (silence_weighting_.Active()) {
      silence_weighting_mutex_.Lock();
      std::vector<std::pair<int32, BaseFloat> > delta_weights;
      silence_weighting_.GetDeltaWeights(feature_pipeline_.NumFramesReady(),
                                         &delta_weights);
      silence_weighting_mutex_.Unlock();
      feature_pipeline_.UpdateFrameWeights(delta_weights);
    }

    int32 num_frames_ready = feature_pipeline_.NumFramesReady(),
        num_frames_usable = num_frames_ready - num_frames_output;
    bool features_done = feature_pipeline_.IsLastFrame(num_frames_ready - 1);

    int32 num_frames_evaluate = std::min<int32>(num_frames_usable,
                                                config_.nnet_batch_size);

    Matrix<BaseFloat> feats;
    if (num_frames_evaluate > 0) {
      // we have something to do...
      feats.Resize(num_frames_evaluate, feature_pipeline_.Dim());
      for (int32 i = 0; i < num_frames_evaluate; i++) {
        int32 t = num_frames_output + i;
        SubVector<BaseFloat> feat(feats, i);
        feature_pipeline_.GetFrame(t, &feat);
      }
    }
    /****** End locking of feature pipeline mutex. ******/
    feature_pipeline_mutex_.Unlock();

    CuMatrix<BaseFloat> cu_loglikes;

    if (feature_buffer_.empty()) {
      if (feature_buffer_finished_) {
    if (feats.NumRows() == 0) {
      if (features_done) {
        // flush out the last few frames.  Note: this is the only place from
        // which we check feature_buffer_finished_, and we'll exit the loop, so
        // if we reach here it must be the first time it was true.
        last_time = true;
        if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
          return false;
        computer.Flush(&cu_loglikes);
        ProcessLoglikes(log_inv_prior, &cu_loglikes);
      } else {
        // there is nothing to do because there is no input.  Next call to Lock
        // should block till the feature-processing thread does something.
        if (!feature_synchronizer_.UnlockFailure(ThreadSynchronizer::kConsumer))
          return false;
      }
    } else {
      int32 num_frames_evaluate = std::min<int32>(feature_buffer_.size(),
                                                  config_.nnet_batch_size),
          feat_dim = feature_buffer_[0]->Dim();
      Matrix<BaseFloat> feats(num_frames_evaluate, feat_dim);
      for (int32 i = 0; i < num_frames_evaluate; i++) {
        feats.Row(i).CopyFromVec(*(feature_buffer_[i]));
        delete feature_buffer_[i];
      }
      feature_buffer_.erase(feature_buffer_.begin(),
                            feature_buffer_.begin() + num_frames_evaluate);
      feature_buffer_start_frame_ += num_frames_evaluate;
      if (!feature_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
        return false;

      CuMatrix<BaseFloat> cu_feats;
      cu_feats.Swap(&feats);  // If we don't have a GPU (and not having a GPU is
                              // the normal expected use-case for this code),

@@ -555,7 +500,7 @@ bool SingleUtteranceNnet2DecoderThreaded::RunNnetEvaluationInternal() {
      computer.Compute(cu_feats, &cu_loglikes);
      ProcessLoglikes(log_inv_prior, &cu_loglikes);
    }

    Matrix<BaseFloat> loglikes;
    loglikes.Swap(&cu_loglikes);  // If we don't have a GPU (and not having a
                                  // GPU is the normal expected use-case for

@@ -633,6 +578,12 @@ bool SingleUtteranceNnet2DecoderThreaded::RunDecoderSearchInternal() {
      decoder_mutex_.Lock();
      decoder_.AdvanceDecoding(&decodable_, config_.decode_batch_size);
      num_frames_decoded = decoder_.NumFramesDecoded();
      if (silence_weighting_.Active()) {
        silence_weighting_mutex_.Lock();
        // the next function does not trace back all the way; it's very fast.
        silence_weighting_.ComputeCurrentTraceback(decoder_);
        silence_weighting_mutex_.Unlock();
      }
      decoder_mutex_.Unlock();
      num_frames_decoded_ = num_frames_decoded;
      if (!decodable_synchronizer_.UnlockSuccess(ThreadSynchronizer::kConsumer))
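
The two threads cooperate on silence weighting: the decoder-search thread refreshes the traceback (above), while the nnet-evaluation thread in RunNnetEvaluationInternal() pulls delta weights out and applies them to the feature pipeline. Below is a condensed sketch of that hand-off, with the surrounding class stripped away; these two free functions are illustrative, not the class's actual methods, and each side takes the mutexes in the same order as the real code.

// Decoder-search side: refresh the (cheap, partial) traceback under both
// the decoder mutex and the weighting mutex.
void DecoderSideStep(const LatticeFasterOnlineDecoder &decoder,
                     OnlineSilenceWeighting *weighting,
                     Mutex *decoder_mutex, Mutex *weighting_mutex) {
  decoder_mutex->Lock();
  weighting_mutex->Lock();
  weighting->ComputeCurrentTraceback(decoder);
  weighting_mutex->Unlock();
  decoder_mutex->Unlock();
}

// Nnet-evaluation side: fetch the changed frame weights and hand them to the
// feature pipeline, which forwards them to the iVector extractor.
void NnetSideStep(OnlineNnet2FeaturePipeline *pipeline,
                  OnlineSilenceWeighting *weighting,
                  Mutex *pipeline_mutex, Mutex *weighting_mutex) {
  pipeline_mutex->Lock();
  std::vector<std::pair<int32, BaseFloat> > delta_weights;
  weighting_mutex->Lock();
  weighting->GetDeltaWeights(pipeline->NumFramesReady(), &delta_weights);
  weighting_mutex->Unlock();
  pipeline->UpdateFrameWeights(delta_weights);  // reweights iVector stats.
  pipeline_mutex->Unlock();
}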
@@ -298,18 +298,12 @@ class SingleUtteranceNnet2DecoderThreaded {
  // from Wait().  If called twice it is not an error.
  void WaitForAllThreads();


  // this function runs the thread that does the feature extraction; ptr_in is
  // to class SingleUtteranceNnet2DecoderThreaded.  Always returns NULL, but in
  // case of failure, calls ptr_in->AbortAllThreads(true).
  static void* RunFeatureExtraction(void *ptr_in);
  // member-function version of RunFeatureExtraction, called by RunFeatureExtraction.
  bool RunFeatureExtractionInternal();


  // this function runs the thread that does the neural-net evaluation; ptr_in is
  // to class SingleUtteranceNnet2DecoderThreaded.  Always returns NULL, but in
  // case of failure, calls ptr_in->AbortAllThreads(true).
  // this function runs the thread that does the feature extraction and
  // neural-net evaluation.  ptr_in is to class
  // SingleUtteranceNnet2DecoderThreaded.  Always returns NULL, but in case of
  // failure, calls ptr_in->AbortAllThreads(true).
  static void* RunNnetEvaluation(void *ptr_in);
  // member-function version of RunNnetEvaluation, called by RunNnetEvaluation.
  bool RunNnetEvaluationInternal();

@@ -317,6 +311,11 @@ class SingleUtteranceNnet2DecoderThreaded {
  // takes the log and subtracts the prior.
  void ProcessLoglikes(const CuVector<BaseFloat> &log_inv_prior,
                       CuMatrixBase<BaseFloat> *loglikes);
  // called from RunNnetEvaluationInternal().  Returns true in the normal case,
  // false on error; if it returns false, then we expect that the calling thread
  // will terminate.  This assumes the caller has already
  // locked feature_pipeline_mutex_.
  bool FeatureComputation(int32 num_frames_output);


  // this function runs the thread that does the neural-net evaluation; ptr_in is
|
@ -349,28 +348,21 @@ class SingleUtteranceNnet2DecoderThreaded {
|
|||
// called from the main thread, the main thread sets input_finished_ = true.
|
||||
// sampling_rate_ is only needed for checking that it matches the config.
|
||||
bool input_finished_;
|
||||
std::vector< Vector<BaseFloat>* > input_waveform_;
|
||||
std::deque< Vector<BaseFloat>* > input_waveform_;
|
||||
ThreadSynchronizer waveform_synchronizer_;
|
||||
|
||||
// feature_pipeline_ is only accessed by the feature-processing thread, and by
|
||||
// the main thread if GetAdaptionState() is called. It is guarded by feature_pipeline_mutex_.
|
||||
// feature_pipeline_ is accessed by the nnet-evaluation thread, by the main
|
||||
// thread if GetAdaptionState() is called, and by the decoding thread via
|
||||
// ComputeCurrentTraceback() if online silence weighting is being used. It is
|
||||
// guarded by feature_pipeline_mutex_.
|
||||
OnlineNnet2FeaturePipeline feature_pipeline_;
|
||||
Mutex feature_pipeline_mutex_;
|
||||
|
||||
// The following three variables are guarded by feature_synchronizer_; they are
|
||||
// a buffer of features that is passed between the feature-processing thread
|
||||
// and the neural-net evaluation thread. feature_buffer_start_frame_ is the
|
||||
// frame index of the first feature vector in the buffer. After the
|
||||
// neural-net evaluation thread reads the features, it removes them from the
|
||||
// vector "feature_buffer_" and advances "feature_buffer_start_frame_"
|
||||
// appropriately. The feature-processing thread does not advance
|
||||
// feature_buffer_start_frame_, but appends to feature_buffer_. After
|
||||
// all features are done (because user called InputFinished()), the
|
||||
// feature-processing thread sets feature_buffer_finished_.
|
||||
int32 feature_buffer_start_frame_;
|
||||
bool feature_buffer_finished_;
|
||||
std::vector<Vector<BaseFloat>* > feature_buffer_;
|
||||
ThreadSynchronizer feature_synchronizer_;
|
||||
|
||||
// This object is used to control the (optional) downweighting of silence in iVector estimation,
|
||||
// which is based on the decoder traceback.
|
||||
OnlineSilenceWeighting silence_weighting_;
|
||||
Mutex silence_weighting_mutex_;
|
||||
|
||||
|
||||
// this Decodable object just stores a matrix of scaled log-likelihoods
|
||||
// obtained by the nnet-evaluation thread. It is produced by the
|
||||
|
@ -393,10 +385,10 @@ class SingleUtteranceNnet2DecoderThreaded {
|
|||
// GetLattice() and GetBestPath().
|
||||
Mutex decoder_mutex_;
|
||||
|
||||
// This contains the thread pointers for the feature-extraction,
|
||||
// nnet-evaluation, and decoder-search threads respectively (or NULL if they
|
||||
// have been joined in Wait()).
|
||||
pthread_t threads_[3];
|
||||
// This contains the thread pointers for the nnet-evaluation and
|
||||
// decoder-search threads respectively (or NULL if they have been joined in
|
||||
// Wait()).
|
||||
pthread_t threads_[2];
|
||||
|
||||
// This is set to true if AbortAllThreads was called for any reason, including
|
||||
// if someone called TerminateDecoding().
|
||||
|
|
|
@@ -104,6 +104,8 @@ class SingleUtteranceNnet2Decoder {
  /// with the required arguments.
  bool EndpointDetected(const OnlineEndpointConfig &config);

  const LatticeFasterOnlineDecoder &Decoder() const { return decoder_; }

  ~SingleUtteranceNnet2Decoder() { }
 private:
@@ -23,7 +23,8 @@
namespace kaldi {

OnlineNnet2FeaturePipelineInfo::OnlineNnet2FeaturePipelineInfo(
    const OnlineNnet2FeaturePipelineConfig &config) {
    const OnlineNnet2FeaturePipelineConfig &config):
    silence_weighting_config(config.silence_weighting_config) {
  if (config.feature_type == "mfcc" || config.feature_type == "plp" ||
      config.feature_type == "fbank") {
    feature_type = config.feature_type;
@@ -167,6 +168,12 @@ void OnlineNnet2FeaturePipeline::AcceptWaveform(
    pitch_->AcceptWaveform(sampling_rate, waveform);
}

void OnlineNnet2FeaturePipeline::UpdateFrameWeights(
    const std::vector<std::pair<int32, BaseFloat> > &delta_weights) {
  if (ivector_feature_ != NULL)
    ivector_feature_->UpdateFrameWeights(delta_weights);
}

void OnlineNnet2FeaturePipeline::InputFinished() {
  base_feature_->InputFinished();
  if (pitch_)
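
UpdateFrameWeights() forwards (frame-index, weight) pairs to the iVector extractor, which rescales those frames' contribution to the iVector stats. A small sketch of building such a list, downweighting an assumed set of silence frames; note that in the real pipeline these pairs come from OnlineSilenceWeighting (derived from the decoder traceback) and represent changes relative to previously applied weights, whereas this illustrative helper just pairs each given frame with its new weight. Kaldi's int32/BaseFloat typedefs are assumed.

#include <utility>
#include <vector>

// Illustrative only: pair each known-silence frame with silence_weight.
std::vector<std::pair<int32, BaseFloat> > MakeDeltaWeights(
    const std::vector<int32> &silence_frames, BaseFloat silence_weight) {
  std::vector<std::pair<int32, BaseFloat> > delta_weights;
  for (size_t i = 0; i < silence_frames.size(); i++)
    delta_weights.push_back(std::make_pair(silence_frames[i], silence_weight));
  return delta_weights;
}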
@@ -80,6 +80,11 @@ struct OnlineNnet2FeaturePipelineConfig {
  // OnlineIvectorExtractionConfig.
  std::string ivector_extraction_config;

  // Config that relates to how we weight silence for (iVector) adaptation.
  // This is registered directly to the command line, as you might want to
  // play with it at test time.
  OnlineSilenceWeightingConfig silence_weighting_config;

  OnlineNnet2FeaturePipelineConfig():
      feature_type("mfcc"), add_pitch(false) { }
@@ -101,6 +106,7 @@ struct OnlineNnet2FeaturePipelineConfig {
    po->Register("ivector-extraction-config", &ivector_extraction_config,
                 "Configuration file for online iVector extraction, "
                 "see class OnlineIvectorExtractionConfig in the code");
    silence_weighting_config.RegisterWithPrefix("ivector-silence-weighting", po);
  }
};
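
RegisterWithPrefix() is what turns the members of OnlineSilenceWeightingConfig into dotted command-line flags such as --ivector-silence-weighting.silence-weight, the names the decode script passes through. A minimal sketch of how a config struct can implement that pattern, assuming the prefix-taking ParseOptions constructor in Kaldi's util/parse-options.h; the struct and member names here are illustrative, not the real config:

#include "util/parse-options.h"

// Illustrative config struct following the RegisterWithPrefix pattern above.
struct DemoWeightingConfig {
  BaseFloat silence_weight;
  DemoWeightingConfig(): silence_weight(1.0) { }

  void Register(OptionsItf *po) {
    po->Register("silence-weight", &silence_weight,
                 "Weight placed on silence frames in iVector stats.");
  }

  // Registers the same options under a dotted prefix, e.g.
  // --ivector-silence-weighting.silence-weight, by wrapping the parent
  // options object in a prefixed ParseOptions.
  void RegisterWithPrefix(const std::string &prefix, OptionsItf *po) {
    ParseOptions prefixed_po(prefix, po);
    Register(&prefixed_po);
  }
};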
@@ -141,6 +147,13 @@ struct OnlineNnet2FeaturePipelineInfo {
  bool use_ivectors;
  OnlineIvectorExtractionInfo ivector_extractor_info;

  // Config for weighting silence in iVector adaptation.
  // We declare this outside of ivector_extractor_info... it was
  // just easier to set up the code that way; and also we think
  // it's the kind of thing you might want to play with directly
  // on the command line instead of inside sub-config-files.
  OnlineSilenceWeightingConfig silence_weighting_config;

  int32 IvectorDim() { return ivector_extractor_info.extractor.IvectorDim(); }
 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(OnlineNnet2FeaturePipelineInfo);
@@ -202,7 +215,13 @@ class OnlineNnet2FeaturePipeline: public OnlineFeatureInterface {
  /// to assert it equals what's in the config.
  void AcceptWaveform(BaseFloat sampling_rate,
                      const VectorBase<BaseFloat> &waveform);

  /// This is used in case you are downweighting silence in the iVector
  /// estimation using the decoder traceback.
  void UpdateFrameWeights(
      const std::vector<std::pair<int32, BaseFloat> > &delta_weights);

  BaseFloat FrameShiftInSeconds() const { return info_.FrameShiftInSeconds(); }

  /// If you call InputFinished(), it tells the class you won't be providing any
@@ -194,6 +194,10 @@ int main(int argc, char *argv[]) {

        OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
        feature_pipeline.SetAdaptationState(adaptation_state);

        OnlineSilenceWeighting silence_weighting(
            trans_model,
            feature_info.silence_weighting_config);

        SingleUtteranceNnet2Decoder decoder(nnet2_decoding_config,
                                            trans_model,
@@ -212,6 +216,8 @@ int main(int argc, char *argv[]) {
        }

        int32 samp_offset = 0;
        std::vector<std::pair<int32, BaseFloat> > delta_weights;

        while (samp_offset < data.Dim()) {
          int32 samp_remaining = data.Dim() - samp_offset;
          int32 num_samp = chunk_length < samp_remaining ? chunk_length
@@ -219,13 +225,21 @@ int main(int argc, char *argv[]) {

          SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
          feature_pipeline.AcceptWaveform(samp_freq, wave_part);

          samp_offset += num_samp;
          decoding_timer.WaitUntil(samp_offset / samp_freq);
          if (samp_offset == data.Dim()) {
            // no more input.  flush out last frames
            feature_pipeline.InputFinished();
          }

          if (silence_weighting.Active()) {
            silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
            silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
                                              &delta_weights);
            feature_pipeline.UpdateFrameWeights(delta_weights);
          }

          decoder.AdvanceDecoding();

          if (do_endpointing && decoder.EndpointDetected(endpoint_config))
@@ -105,7 +105,8 @@ int main(int argc, char *argv[]) {

    BaseFloat chunk_length_secs = 0.05;
    bool do_endpointing = false;
    bool modify_ivector_config = false;
    bool simulate_realtime_decoding = true;
po.Register("chunk-length", &chunk_length_secs,
|
||||
"Length of chunk size in seconds, that we provide each time to the "
|
||||
|
@ -121,6 +122,10 @@ int main(int argc, char *argv[]) {
|
|||
"This will give the best possible results, but the results may become dependent "
|
||||
"on the speed of your machine (slower machine -> better results). Compare "
|
||||
"to the --online option in online2-wav-nnet2-latgen-faster");
|
||||
po.Register("simulate-realtime-decoding", &simulate_realtime_decoding,
|
||||
"If true, simulate real-time decoding scenario by providing the "
|
||||
"data incrementally, calling sleep() until each piece is ready. "
|
||||
"If false, don't sleep (so it will be faster).");
|
||||
|
||||
feature_config.Register(&po);
|
||||
nnet2_decoding_config.Register(&po);
|
||||
|
@@ -166,6 +171,7 @@ int main(int argc, char *argv[]) {
    int32 num_done = 0, num_err = 0;
    double tot_like = 0.0;
    int64 num_frames = 0;
    Timer global_timer;

    SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
    RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
@@ -213,8 +219,11 @@ int main(int argc, char *argv[]) {
          decoder.AcceptWaveform(samp_freq, wave_part);

          samp_offset += num_samp;
          // Note: the next call may actually call sleep().
          decoding_timer.SleepUntil(samp_offset / samp_freq);
          if (simulate_realtime_decoding) {
            // Note: the next call may actually call sleep().
            decoding_timer.SleepUntil(samp_offset / samp_freq);
          }
          if (samp_offset == data.Dim()) {
            // no more input.  flush out last frames
            decoder.InputFinished();
@@ -227,8 +236,10 @@ int main(int argc, char *argv[]) {
          }
          Timer timer;
          decoder.Wait();
          KALDI_VLOG(1) << "Waited " << timer.Elapsed() << " seconds for decoder to "
                        << "finish after giving it last chunk.";
          if (simulate_realtime_decoding) {
            KALDI_VLOG(1) << "Waited " << timer.Elapsed() << " seconds for decoder to "
                          << "finish after giving it last chunk.";
          }
          decoder.FinalizeDecoding();

          CompactLattice clat;
@@ -249,9 +260,10 @@ int main(int argc, char *argv[]) {
              1.0 / nnet2_decoding_config.acoustic_scale;
          ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat);

          KALDI_VLOG(1) << "Adding the various end-of-utterance tasks took the "
                        << "total latency to " << timer.Elapsed() << " seconds.";
          if (simulate_realtime_decoding) {
            KALDI_VLOG(1) << "Adding the various end-of-utterance tasks took the "
                          << "total latency to " << timer.Elapsed() << " seconds.";
          }
          clat_writer.Write(utt, clat);
          KALDI_LOG << "Decoded utterance " << utt;
|
@ -261,7 +273,17 @@ int main(int argc, char *argv[]) {
|
|||
}
|
||||
}
|
||||
bool online = true;
|
||||
timing_stats.Print(online);
|
||||
|
||||
if (simulate_realtime_decoding) {
|
||||
timing_stats.Print(online);
|
||||
} else {
|
||||
BaseFloat frame_shift = 0.01;
|
||||
BaseFloat real_time_factor =
|
||||
global_timer.Elapsed() / (frame_shift * num_frames);
|
||||
if (num_frames > 0)
|
||||
KALDI_LOG << "Real-time factor was " << real_time_factor
|
||||
<< " assuming frame shift of " << frame_shift;
|
||||
}
|
||||
|
||||
KALDI_LOG << "Decoded " << num_done << " utterances, "
|
||||
<< num_err << " with errors.";
|
||||
|
|
|
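
The real-time factor computed above is elapsed wall-clock time divided by audio duration (num_frames times the frame shift). For example, at a 0.01 s frame shift, 360,000 decoded frames correspond to 3,600 s of audio; if global_timer.Elapsed() returned 1,800 s, the real-time factor would be 0.5, i.e. twice as fast as real time. A standalone restatement of the arithmetic, with a hypothetical helper name:

#include <cassert>

// Real-time factor = processing time / audio time; values below 1.0 mean
// faster than real time.  Mirrors the computation in the tool above.
double RealTimeFactor(double elapsed_seconds, long long num_frames,
                      double frame_shift_seconds) {
  assert(num_frames > 0);
  return elapsed_seconds / (frame_shift_seconds * num_frames);
}
// e.g. RealTimeFactor(1800.0, 360000, 0.01) == 0.5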
@@ -170,6 +170,11 @@ void TestConvertStringToReal() {
  KALDI_ASSERT(!ConvertStringToReal("-1f", &d));
  KALDI_ASSERT(ConvertStringToReal("12345.2", &d) && fabs(d-12345.2) < 1.0);
  KALDI_ASSERT(ConvertStringToReal("1.0e+08", &d) && fabs(d-1.0e+08) < 100.0);

  // it also works for inf or nan.
  KALDI_ASSERT(ConvertStringToReal("inf", &d) && d > 0 && d - d != 0);
  KALDI_ASSERT(ConvertStringToReal("nan", &d) && d != d);
}
@@ -133,6 +133,7 @@ bool ConvertStringToInteger(const std::string &str,
/// ConvertStringToReal converts a string into either float or double via strtod,
/// and returns false if there was any kind of problem (i.e. the string was not a
/// floating point number or contained extra non-whitespace junk).
/// Be careful: this function will successfully read inf's or nan's.
bool ConvertStringToReal(const std::string &str,
                         double *out);
bool ConvertStringToReal(const std::string &str,
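
Since, as the tests above make explicit, ConvertStringToReal() happily accepts "inf" and "nan", callers that need a finite value have to reject those themselves. A short sketch of such a wrapper; the helper name is hypothetical:

#include <string>
#include "util/text-utils.h"

// Sketch: parse a string that must be a finite real number.
bool ParseFiniteReal(const std::string &s, double *out) {
  double d;
  if (!kaldi::ConvertStringToReal(s, &d))
    return false;              // not a number at all.
  if (d != d || d - d != 0.0)  // rejects NaN and +/-inf, using the same
    return false;              // tricks as the tests: nan != nan, inf - inf != 0.
  *out = d;
  return true;
}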
@@ -4,7 +4,8 @@ CXX = g++
# CXX = clang++  # Uncomment this line to build with Clang.

OPENFST_VERSION = 1.3.4
# OPENFST_VERSION = 1.4.1  # Uncomment this line to build with OpenFst-1.4.1.
# Uncomment the next line to build with OpenFst-1.4.1.
# OPENFST_VERSION = 1.4.1
# Note: OpenFst >= 1.4 requires C++11 support, hence you will need to use a
# relatively recent C++ compiler, e.g. gcc >= 4.6, clang >= 3.0.