зеркало из https://github.com/mozilla/kaldi.git
trunk: various script extensions relating to vad.scp; extensions to gmm-global-init-from-feats.cc for UBM initialization (multi-threaded operation; mixture splitting enabled)
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@3229 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Родитель
38c0e03b4b
Коммит
feefd031ab
|
@ -43,7 +43,7 @@ texts=""
|
|||
|
||||
nu=`cat $data/utt2spk | wc -l`
|
||||
nf=`cat $data/feats.scp | wc -l`
|
||||
nt=`cat $data/text | wc -l`
|
||||
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
|
||||
if [ $nu -ne $nf ]; then
|
||||
echo "split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); this script "
|
||||
echo " may produce incorrectly split data."
|
||||
|
@ -61,7 +61,7 @@ if [ ! -d $s1 ]; then
|
|||
else
|
||||
need_to_split=false
|
||||
for f in utt2spk spk2utt feats.scp text wav.scp cmvn.scp spk2gender \
|
||||
segments reco2file_and_channel; do
|
||||
vad.scp segments reco2file_and_channel; do
|
||||
if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then
|
||||
need_to_split=true
|
||||
fi
|
||||
|
@ -75,6 +75,7 @@ fi
|
|||
for n in `seq $numsplit`; do
|
||||
mkdir -p $data/split$numsplit/$n
|
||||
feats="$feats $data/split$numsplit/$n/feats.scp"
|
||||
vads="$vads $data/split$numsplit/$n/vad.scp"
|
||||
texts="$texts $data/split$numsplit/$n/text"
|
||||
utt2spks="$utt2spks $data/split$numsplit/$n/utt2spk"
|
||||
done
|
||||
|
@ -88,8 +89,10 @@ fi
|
|||
utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
|
||||
|
||||
utils/split_scp.pl $utt2spk_opt $data/feats.scp $feats || exit 1
|
||||
[ -f $data/text ] && \
|
||||
utils/split_scp.pl $utt2spk_opt $data/text $texts
|
||||
|
||||
[ -f $data/text ] && utils/split_scp.pl $utt2spk_opt $data/text $texts
|
||||
|
||||
[ -f $data/vad.scp ] && utils/split_scp.pl $utt2spk_opt $data/vad.scp $vads
|
||||
|
||||
# If lockfile is not installed, just don't lock it. It's not a big deal.
|
||||
which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
# 2012-2013 Johns Hopkins University (Author: Daniel Povey)
|
||||
# Apache 2.0
|
||||
|
||||
|
||||
|
@ -89,6 +90,7 @@ fi
|
|||
function do_filtering {
|
||||
# assumes the utt2spk and spk2utt files already exist.
|
||||
[ -f $srcdir/feats.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/feats.scp >$destdir/feats.scp
|
||||
[ -f $srcdir/vad.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp
|
||||
[ -f $srcdir/wav.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp
|
||||
[ -f $srcdir/text ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/text >$destdir/text
|
||||
[ -f $srcdir/spk2gender ] && utils/filter_scp.pl $destdir/spk2utt <$srcdir/spk2gender >$destdir/spk2gender
|
||||
|
|
|
@ -14,8 +14,8 @@ OBJFILES = diag-gmm.o diag-gmm-normal.o mle-diag-gmm.o am-diag-gmm.o \
|
|||
|
||||
LIBNAME = kaldi-gmm
|
||||
|
||||
ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
|
||||
../base/kaldi-base.a
|
||||
ADDLIBS = ../tree/kaldi-tree.a ../thread/kaldi-thread.a ../util/kaldi-util.a \
|
||||
../matrix/kaldi-matrix.a ../base/kaldi-base.a
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -372,6 +372,31 @@ UnitTestEstimateDiagGmm() {
|
|||
test_io(*gmm, est_gmm, true, feats); // Binary mode
|
||||
}
|
||||
|
||||
{ // Test multi-threaded update.
|
||||
GmmFlagsType flags_all = kGmmAll;
|
||||
est_gmm.Resize(gmm->NumGauss(),
|
||||
gmm->Dim(), flags_all);
|
||||
est_gmm.SetZero(flags_all);
|
||||
|
||||
Vector<BaseFloat> weights(counter);
|
||||
for (size_t i = 0; i < counter; i++)
|
||||
weights(i) = 0.5 + 0.1 * (rand() % 10);
|
||||
|
||||
|
||||
float loglike = 0.0;
|
||||
for (size_t i = 0; i < counter; i++) {
|
||||
loglike += weights(i) *
|
||||
est_gmm.AccumulateFromDiag(*gmm, feats.Row(i), weights(i));
|
||||
}
|
||||
AccumDiagGmm est_gmm2(*gmm, flags_all);
|
||||
int32 num_threads = 2;
|
||||
float loglike2 =
|
||||
est_gmm2.AccumulateFromDiagMultiThreaded(*gmm, feats, weights, num_threads);
|
||||
AssertEqual(loglike, loglike2);
|
||||
est_gmm.AssertEqual(est_gmm2);
|
||||
}
|
||||
|
||||
|
||||
delete gmm;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// gmm/mle-diag-gmm.cc
|
||||
|
||||
// Copyright 2009-2012 Saarland University; Georg Stemmer; Jan Silovsky;
|
||||
// Copyright 2009-2013 Saarland University; Georg Stemmer; Jan Silovsky;
|
||||
// Microsoft Corporation; Yanmin Qian;
|
||||
// Johns Hopkins University (author: Daniel Povey);
|
||||
// Cisco Systems (author: Neha Agrawal)
|
||||
|
@ -26,6 +26,7 @@
|
|||
|
||||
#include "gmm/diag-gmm.h"
|
||||
#include "gmm/mle-diag-gmm.h"
|
||||
#include "thread/kaldi-thread.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
|
@ -202,7 +203,6 @@ BaseFloat AccumDiagGmm::AccumulateFromDiag(const DiagGmm &gmm,
|
|||
return log_like;
|
||||
}
|
||||
|
||||
|
||||
// Careful: this wouldn't be valid if it were used to update the
|
||||
// Gaussian weights.
|
||||
void AccumDiagGmm::SmoothStats(BaseFloat tau) {
|
||||
|
@ -478,5 +478,84 @@ void MapDiagGmmUpdate(const MapDiagGmmOptions &config,
|
|||
}
|
||||
|
||||
|
||||
class AccumulateMultiThreadedClass: public MultiThreadable {
|
||||
public:
|
||||
AccumulateMultiThreadedClass(const DiagGmm &diag_gmm,
|
||||
const MatrixBase<BaseFloat> &data,
|
||||
const VectorBase<BaseFloat> &frame_weights,
|
||||
AccumDiagGmm *accum,
|
||||
double *tot_like):
|
||||
diag_gmm_(diag_gmm), data_(data),
|
||||
frame_weights_(frame_weights), dest_accum_(accum),
|
||||
tot_like_ptr_(tot_like), tot_like_(0.0) { }
|
||||
AccumulateMultiThreadedClass(const AccumulateMultiThreadedClass &other):
|
||||
diag_gmm_(other.diag_gmm_), data_(other.data_),
|
||||
frame_weights_(other.frame_weights_), dest_accum_(other.dest_accum_),
|
||||
accum_(diag_gmm_, dest_accum_->Flags()), tot_like_ptr_(other.tot_like_ptr_),
|
||||
tot_like_(0.0) {
|
||||
KALDI_ASSERT(data_.NumRows() == frame_weights_.Dim());
|
||||
}
|
||||
void operator () () {
|
||||
int32 num_frames = data_.NumRows(), num_threads = num_threads_,
|
||||
block_size = (num_frames + num_threads - 1) / num_threads,
|
||||
block_start = block_size * thread_id_,
|
||||
block_end = std::min(num_frames, block_start + block_size);
|
||||
tot_like_ = 0.0;
|
||||
double tot_weight = 0.0;
|
||||
for (int32 t = block_start; t < block_end; t++) {
|
||||
tot_like_ += frame_weights_(t) *
|
||||
accum_.AccumulateFromDiag(diag_gmm_, data_.Row(t), frame_weights_(t));
|
||||
tot_weight += frame_weights_(t);
|
||||
}
|
||||
KALDI_VLOG(3) << "Thread " << thread_id_ << " saw average likeliood/frame "
|
||||
<< (tot_like_ / tot_weight) << " over " << tot_weight
|
||||
<< " (weighted) frames.";
|
||||
}
|
||||
~AccumulateMultiThreadedClass() {
|
||||
if (accum_.Dim() != 0) { // if our accumulator is set up (this is not true
|
||||
// for the single object we use to initialize the others)
|
||||
dest_accum_->Add(1.0, accum_);
|
||||
*tot_like_ptr_ += tot_like_;
|
||||
}
|
||||
}
|
||||
private:
|
||||
const DiagGmm &diag_gmm_;
|
||||
const MatrixBase<BaseFloat> &data_;
|
||||
const VectorBase<BaseFloat> &frame_weights_;
|
||||
AccumDiagGmm *dest_accum_;
|
||||
AccumDiagGmm accum_;
|
||||
double *tot_like_ptr_;
|
||||
double tot_like_;
|
||||
};
|
||||
|
||||
|
||||
// Accumulates statistics from every row of "data" (weighted by the matching
// entry of "frame_weights") into this object, handing the work to
// AccumulateMultiThreadedClass objects run by a MultiThreader built with
// "num_threads".  Returns the likelihood total that those objects add up.
BaseFloat AccumDiagGmm::AccumulateFromDiagMultiThreaded(
    const DiagGmm &gmm,
    const MatrixBase<BaseFloat> &data,
    const VectorBase<BaseFloat> &frame_weights,
    int32 num_threads) {
  double total_loglike = 0.0;

  // This object only serves as the template handed to the MultiThreader; the
  // results are delivered through "this" and "total_loglike".
  AccumulateMultiThreadedClass prototype(gmm, data, frame_weights,
                                         this, &total_loglike);

  {
    // All of the work is done by the constructor and destructor of "runner".
    // The enclosing braces make sure it has been destroyed before
    // total_loglike is read below.
    MultiThreader<AccumulateMultiThreadedClass> runner(num_threads,
                                                       prototype);
  }

  return total_loglike;
}
|
||||
|
||||
// Testing helper: asserts (via KALDI_ASSERT) that this accumulator matches
// "other".  The dimension, number of components and flags must be exactly
// equal; the occupancy, mean and variance statistics are compared with
// ApproxEqual() using its default arguments.
void AccumDiagGmm::AssertEqual(const AccumDiagGmm &other) {
|
||||
// Exact match required on the sizes and the flags.
KALDI_ASSERT(dim_ == other.dim_ && num_comp_ == other.num_comp_ &&
|
||||
flags_ == other.flags_);
|
||||
// Accumulated statistics only need to agree approximately.
KALDI_ASSERT(occupancy_.ApproxEqual(other.occupancy_));
|
||||
KALDI_ASSERT(mean_accumulator_.ApproxEqual(other.mean_accumulator_));
|
||||
KALDI_ASSERT(variance_accumulator_.ApproxEqual(other.variance_accumulator_));
|
||||
}
|
||||
|
||||
|
||||
} // End of namespace kaldi
|
||||
|
|
|
@ -142,6 +142,16 @@ class AccumDiagGmm {
|
|||
const VectorBase<BaseFloat> &data,
|
||||
BaseFloat frame_posterior);
|
||||
|
||||
/// This does the same job as AccumulateFromDiag, but using
|
||||
/// multiple threads. Returns sum of (log-likelihood times
|
||||
/// frame weight) over all frames.
|
||||
BaseFloat AccumulateFromDiagMultiThreaded(
|
||||
const DiagGmm &gmm,
|
||||
const MatrixBase<BaseFloat> &data,
|
||||
const VectorBase<BaseFloat> &frame_weights,
|
||||
int32 num_threads);
|
||||
|
||||
|
||||
/// Increment the stats for this component by the specified amount
|
||||
/// (not all parts may be taken, depending on flags).
|
||||
/// Note: x_stats and x2_stats are assumed to already be multiplied by "occ"
|
||||
|
@ -173,7 +183,9 @@ class AccumDiagGmm {
|
|||
const VectorBase<double> &occupancy() const { return occupancy_; }
|
||||
const MatrixBase<double> &mean_accumulator() const { return mean_accumulator_; }
|
||||
const MatrixBase<double> &variance_accumulator() const { return variance_accumulator_; }
|
||||
|
||||
|
||||
// used in testing.
|
||||
void AssertEqual(const AccumDiagGmm &other);
|
||||
private:
|
||||
int32 dim_;
|
||||
int32 num_comp_;
|
||||
|
|
|
@ -61,13 +61,16 @@ void InitGmmFromRandomFrames(const Matrix<BaseFloat> &feats, DiagGmm *gmm) {
|
|||
void TrainOneIter(const Matrix<BaseFloat> &feats,
|
||||
const MleDiagGmmOptions &gmm_opts,
|
||||
int32 iter,
|
||||
int32 num_threads,
|
||||
DiagGmm *gmm) {
|
||||
AccumDiagGmm gmm_acc(*gmm, kGmmAll);
|
||||
|
||||
double tot_like = 0.0;
|
||||
|
||||
for (int32 t = 0; t < feats.NumRows(); t++)
|
||||
tot_like += gmm_acc.AccumulateFromDiag(*gmm, feats.Row(t), 1.0);
|
||||
Vector<BaseFloat> frame_weights(feats.NumRows(), kUndefined);
|
||||
frame_weights.Set(1.0);
|
||||
|
||||
double tot_like;
|
||||
tot_like = gmm_acc.AccumulateFromDiagMultiThreaded(*gmm, feats, frame_weights,
|
||||
num_threads);
|
||||
|
||||
KALDI_LOG << "Likelihood per frame on iteration " << iter
|
||||
<< " was " << (tot_like / feats.NumRows()) << " over "
|
||||
|
@ -97,17 +100,24 @@ int main(int argc, char *argv[]) {
|
|||
|
||||
bool binary = true;
|
||||
int32 num_gauss = 100;
|
||||
int32 num_gauss_init = 0;
|
||||
int32 num_iters = 50;
|
||||
int32 num_frames = 200000;
|
||||
int32 srand_seed = 0;
|
||||
int32 num_threads = 4;
|
||||
|
||||
po.Register("binary", &binary, "Write output in binary mode");
|
||||
po.Register("num-gauss", &num_gauss, "Number of Gaussians in the model");
|
||||
po.Register("num-gauss-init", &num_gauss_init, "Number of Gaussians in "
|
||||
"the model initially (if nonzero and less than num_gauss, "
|
||||
"we'll do mixture splitting)");
|
||||
po.Register("num-iters", &num_iters, "Number of iterations of training");
|
||||
po.Register("num-frames", &num_frames, "Number of feature vectors to store in "
|
||||
"memory and train on (randomly chosen from the input features)");
|
||||
po.Register("srand", &srand_seed, "Seed for random number generator ");
|
||||
|
||||
po.Register("num-threads", &num_threads, "Number of threads used for "
|
||||
"statistics accumulation");
|
||||
|
||||
gmm_opts.Register(&po);
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
@ -132,7 +142,7 @@ int main(int argc, char *argv[]) {
|
|||
int64 num_read = 0, dim = 0;
|
||||
|
||||
KALDI_LOG << "Reading features (will keep " << num_frames << " frames.)";
|
||||
|
||||
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
const Matrix<BaseFloat> &this_feats = feature_reader.Value();
|
||||
for (int32 t = 0; t < this_feats.NumRows(); t++) {
|
||||
|
@ -160,15 +170,36 @@ int main(int argc, char *argv[]) {
|
|||
KALDI_WARN << "Number of frames read " << num_read << " was less than "
|
||||
<< "target number " << num_frames << ", using all we read.";
|
||||
feats.Resize(num_read, dim, kCopyData);
|
||||
} else {
|
||||
BaseFloat percent = num_frames * 100.0 / num_read;
|
||||
KALDI_LOG << "Kept " << num_frames << " out of " << num_read
|
||||
<< " input frames = " << percent << "%.";
|
||||
}
|
||||
|
||||
DiagGmm gmm(num_gauss, dim);
|
||||
|
||||
KALDI_LOG << "Initializing GMM means from random frames";
|
||||
InitGmmFromRandomFrames(feats, &gmm);
|
||||
if (num_gauss_init <= 0 || num_gauss_init > num_gauss)
|
||||
num_gauss_init = num_gauss;
|
||||
|
||||
for (int32 iter = 0; iter < num_iters; iter++)
|
||||
TrainOneIter(feats, gmm_opts, iter, &gmm);
|
||||
DiagGmm gmm(num_gauss_init, dim);
|
||||
|
||||
KALDI_LOG << "Initializing GMM means from random frames to "
|
||||
<< num_gauss_init << " Gaussians.";
|
||||
InitGmmFromRandomFrames(feats, &gmm);
|
||||
|
||||
// we'll increase the #Gaussians by splitting,
|
||||
// till halfway through training.
|
||||
int32 cur_num_gauss = num_gauss_init,
|
||||
gauss_inc = (num_gauss - num_gauss_init) / (num_iters / 2);
|
||||
|
||||
for (int32 iter = 0; iter < num_iters; iter++) {
|
||||
TrainOneIter(feats, gmm_opts, iter, num_threads, &gmm);
|
||||
|
||||
int32 next_num_gauss = std::min(num_gauss, cur_num_gauss + gauss_inc);
|
||||
if (next_num_gauss > cur_num_gauss) {
|
||||
KALDI_LOG << "Splitting to " << next_num_gauss << " Gaussians.";
|
||||
gmm.Split(next_num_gauss, 0.1);
|
||||
cur_num_gauss = next_num_gauss;
|
||||
}
|
||||
}
|
||||
|
||||
WriteKaldiObject(gmm, model_wxfilename, binary);
|
||||
KALDI_LOG << "Wrote model to " << model_wxfilename;
|
||||
|
|
Загрузка…
Ссылка в новой задаче