trunk: various script extensions relating to vad.scp; extensions to gmm-global-init-from-feats.cc for UBM initialization (multi-threaded operation; mixture splitting enabled)

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@3229 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Dan Povey 2013-11-28 05:52:41 +00:00
Родитель 38c0e03b4b
Коммит feefd031ab
7 изменённых файлов: 174 добавлений и 22 удалений

Просмотреть файл

@ -43,7 +43,7 @@ texts=""
nu=`cat $data/utt2spk | wc -l`
nf=`cat $data/feats.scp | wc -l`
nt=`cat $data/text | wc -l`
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
if [ $nu -ne $nf ]; then
echo "split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); this script "
echo " may produce incorrectly split data."
@ -61,7 +61,7 @@ if [ ! -d $s1 ]; then
else
need_to_split=false
for f in utt2spk spk2utt feats.scp text wav.scp cmvn.scp spk2gender \
segments reco2file_and_channel; do
vad.scp segments reco2file_and_channel; do
if [[ -f $data/$f && ( ! -f $s1/$f || $s1/$f -ot $data/$f ) ]]; then
need_to_split=true
fi
@ -75,6 +75,7 @@ fi
for n in `seq $numsplit`; do
mkdir -p $data/split$numsplit/$n
feats="$feats $data/split$numsplit/$n/feats.scp"
vads="$vads $data/split$numsplit/$n/vad.scp"
texts="$texts $data/split$numsplit/$n/text"
utt2spks="$utt2spks $data/split$numsplit/$n/utt2spk"
done
@ -88,8 +89,10 @@ fi
utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1
utils/split_scp.pl $utt2spk_opt $data/feats.scp $feats || exit 1
[ -f $data/text ] && \
utils/split_scp.pl $utt2spk_opt $data/text $texts
[ -f $data/text ] && utils/split_scp.pl $utt2spk_opt $data/text $texts
[ -f $data/vad.scp ] && utils/split_scp.pl $utt2spk_opt $data/vad.scp $vads
# If lockfile is not installed, just don't lock it. It's not a big deal.
which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock

Просмотреть файл

@ -1,5 +1,6 @@
#!/bin/bash
# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
# Copyright 2010-2011 Microsoft Corporation
# 2012-2013 Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0
@ -89,6 +90,7 @@ fi
function do_filtering {
# assumes the utt2spk and spk2utt files already exist.
[ -f $srcdir/feats.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/feats.scp >$destdir/feats.scp
[ -f $srcdir/vad.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp
[ -f $srcdir/wav.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp
[ -f $srcdir/text ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/text >$destdir/text
[ -f $srcdir/spk2gender ] && utils/filter_scp.pl $destdir/spk2utt <$srcdir/spk2gender >$destdir/spk2gender

Просмотреть файл

@ -14,8 +14,8 @@ OBJFILES = diag-gmm.o diag-gmm-normal.o mle-diag-gmm.o am-diag-gmm.o \
LIBNAME = kaldi-gmm
ADDLIBS = ../tree/kaldi-tree.a ../util/kaldi-util.a ../matrix/kaldi-matrix.a \
../base/kaldi-base.a
ADDLIBS = ../tree/kaldi-tree.a ../thread/kaldi-thread.a ../util/kaldi-util.a \
../matrix/kaldi-matrix.a ../base/kaldi-base.a

Просмотреть файл

@ -372,6 +372,31 @@ UnitTestEstimateDiagGmm() {
test_io(*gmm, est_gmm, true, feats); // Binary mode
}
{ // Test multi-threaded update.
GmmFlagsType flags_all = kGmmAll;
est_gmm.Resize(gmm->NumGauss(),
gmm->Dim(), flags_all);
est_gmm.SetZero(flags_all);
Vector<BaseFloat> weights(counter);
for (size_t i = 0; i < counter; i++)
weights(i) = 0.5 + 0.1 * (rand() % 10);
float loglike = 0.0;
for (size_t i = 0; i < counter; i++) {
loglike += weights(i) *
est_gmm.AccumulateFromDiag(*gmm, feats.Row(i), weights(i));
}
AccumDiagGmm est_gmm2(*gmm, flags_all);
int32 num_threads = 2;
float loglike2 =
est_gmm2.AccumulateFromDiagMultiThreaded(*gmm, feats, weights, num_threads);
AssertEqual(loglike, loglike2);
est_gmm.AssertEqual(est_gmm2);
}
delete gmm;
}

Просмотреть файл

@ -1,6 +1,6 @@
// gmm/mle-diag-gmm.cc
// Copyright 2009-2012 Saarland University; Georg Stemmer; Jan Silovsky;
// Copyright 2009-2013 Saarland University; Georg Stemmer; Jan Silovsky;
// Microsoft Corporation; Yanmin Qian;
// Johns Hopkins University (author: Daniel Povey);
// Cisco Systems (author: Neha Agrawal)
@ -26,6 +26,7 @@
#include "gmm/diag-gmm.h"
#include "gmm/mle-diag-gmm.h"
#include "thread/kaldi-thread.h"
namespace kaldi {
@ -202,7 +203,6 @@ BaseFloat AccumDiagGmm::AccumulateFromDiag(const DiagGmm &gmm,
return log_like;
}
// Careful: this wouldn't be valid if it were used to update the
// Gaussian weights.
void AccumDiagGmm::SmoothStats(BaseFloat tau) {
@ -478,5 +478,84 @@ void MapDiagGmmUpdate(const MapDiagGmmOptions &config,
}
// Worker object used by AccumDiagGmm::AccumulateFromDiagMultiThreaded().
// The MultiThreader framework copy-constructs one instance of this class per
// thread from a single prototype object; each copy owns a private AccumDiagGmm
// ("accum_") so the threads never write to shared stats concurrently.  The
// per-thread stats and log-likelihood are folded into the shared destination
// ("dest_accum_", "tot_like_ptr_") in the destructor, which the framework is
// expected to run serially after joining the threads.
class AccumulateMultiThreadedClass: public MultiThreadable {
 public:
  // Prototype constructor: sets up references to the shared inputs/outputs but
  // deliberately leaves accum_ empty (Dim() == 0), so the destructor of the
  // prototype itself does not add anything to *accum.
  AccumulateMultiThreadedClass(const DiagGmm &diag_gmm,
                               const MatrixBase<BaseFloat> &data,
                               const VectorBase<BaseFloat> &frame_weights,
                               AccumDiagGmm *accum,
                               double *tot_like):
      diag_gmm_(diag_gmm), data_(data),
      frame_weights_(frame_weights), dest_accum_(accum),
      tot_like_ptr_(tot_like), tot_like_(0.0) { }

  // Copy constructor (used once per worker thread): unlike the prototype, each
  // copy gets a real, zero-initialized accumulator with the same flags as the
  // destination accumulator.
  AccumulateMultiThreadedClass(const AccumulateMultiThreadedClass &other):
      diag_gmm_(other.diag_gmm_), data_(other.data_),
      frame_weights_(other.frame_weights_), dest_accum_(other.dest_accum_),
      accum_(diag_gmm_, dest_accum_->Flags()), tot_like_ptr_(other.tot_like_ptr_),
      tot_like_(0.0) {
    KALDI_ASSERT(data_.NumRows() == frame_weights_.Dim());
  }

  // Thread body: accumulates stats for this thread's contiguous block of
  // frames.  thread_id_ and num_threads_ are inherited from MultiThreadable.
  void operator () () {
    int32 num_frames = data_.NumRows(), num_threads = num_threads_,
        block_size = (num_frames + num_threads - 1) / num_threads,
        block_start = block_size * thread_id_,
        block_end = std::min(num_frames, block_start + block_size);
    tot_like_ = 0.0;
    double tot_weight = 0.0;
    for (int32 t = block_start; t < block_end; t++) {
      tot_like_ += frame_weights_(t) *
          accum_.AccumulateFromDiag(diag_gmm_, data_.Row(t), frame_weights_(t));
      tot_weight += frame_weights_(t);
    }
    // NOTE(review): if this thread's block is empty (num_threads > num_frames)
    // tot_weight is 0.0 and the average below is 0/0; this only affects the
    // VLOG text, not the accumulated stats.
    KALDI_VLOG(3) << "Thread " << thread_id_ << " saw average likelihood/frame "
                  << (tot_like_ / tot_weight) << " over " << tot_weight
                  << " (weighted) frames.";
  }

  // Destructor: per-thread copies (the only ones with a nonempty accum_) add
  // their stats and log-likelihood into the shared destination.
  ~AccumulateMultiThreadedClass() {
    if (accum_.Dim() != 0) {  // if our accumulator is set up (this is not true
                              // for the single object we use to initialize the others)
      dest_accum_->Add(1.0, accum_);
      *tot_like_ptr_ += tot_like_;
    }
  }
 private:
  const DiagGmm &diag_gmm_;                      // model we accumulate against
  const MatrixBase<BaseFloat> &data_;            // features, one frame per row
  const VectorBase<BaseFloat> &frame_weights_;   // per-frame weights
  AccumDiagGmm *dest_accum_;                     // shared output accumulator (not owned)
  AccumDiagGmm accum_;                           // this thread's private stats
  double *tot_like_ptr_;                         // shared output log-likelihood (not owned)
  double tot_like_;                              // this thread's weighted log-likelihood
};
// Multi-threaded counterpart of AccumulateFromDiag(): accumulates stats for
// all rows of "data" (weighted by "frame_weights") into *this using
// "num_threads" worker threads, and returns the total weighted log-likelihood.
BaseFloat AccumDiagGmm::AccumulateFromDiagMultiThreaded(
    const DiagGmm &gmm,
    const MatrixBase<BaseFloat> &data,
    const VectorBase<BaseFloat> &frame_weights,
    int32 num_threads) {
  double total_like = 0.0;
  AccumulateMultiThreadedClass task(gmm, data, frame_weights,
                                    this, &total_like);
  {
    // The MultiThreader's constructor launches the threads and its destructor
    // joins them; all the accumulation into *this and total_like happens in
    // the worker objects' operator() and destructors.  The inner scope makes
    // certain the threader is torn down before total_like is read.
    MultiThreader<AccumulateMultiThreadedClass> threader(num_threads, task);
  }
  return total_like;
}
// Testing utility: crashes (via KALDI_ASSERT) unless "other" has the same
// dimension, component count and flags as *this and its occupancy, mean and
// variance accumulators are approximately equal to ours.  Used by the unit
// test to check that the multi-threaded accumulation matches the
// single-threaded one.
void AccumDiagGmm::AssertEqual(const AccumDiagGmm &other) {
KALDI_ASSERT(dim_ == other.dim_ && num_comp_ == other.num_comp_ &&
flags_ == other.flags_);
// ApproxEqual (rather than exact equality) tolerates floating-point
// round-off differences from summing the stats in a different order.
KALDI_ASSERT(occupancy_.ApproxEqual(other.occupancy_));
KALDI_ASSERT(mean_accumulator_.ApproxEqual(other.mean_accumulator_));
KALDI_ASSERT(variance_accumulator_.ApproxEqual(other.variance_accumulator_));
}
} // End of namespace kaldi

Просмотреть файл

@ -142,6 +142,16 @@ class AccumDiagGmm {
const VectorBase<BaseFloat> &data,
BaseFloat frame_posterior);
/// This does the same job as AccumulateFromDiag, but using
/// multiple threads. Returns sum of (log-likelihood times
/// frame weight) over all frames.
BaseFloat AccumulateFromDiagMultiThreaded(
const DiagGmm &gmm,
const MatrixBase<BaseFloat> &data,
const VectorBase<BaseFloat> &frame_weights,
int32 num_threads);
/// Increment the stats for this component by the specified amount
/// (not all parts may be taken, depending on flags).
/// Note: x_stats and x2_stats are assumed to already be multiplied by "occ"
@ -173,7 +183,9 @@ class AccumDiagGmm {
const VectorBase<double> &occupancy() const { return occupancy_; }
const MatrixBase<double> &mean_accumulator() const { return mean_accumulator_; }
const MatrixBase<double> &variance_accumulator() const { return variance_accumulator_; }
// used in testing.
void AssertEqual(const AccumDiagGmm &other);
private:
int32 dim_;
int32 num_comp_;

Просмотреть файл

@ -61,13 +61,16 @@ void InitGmmFromRandomFrames(const Matrix<BaseFloat> &feats, DiagGmm *gmm) {
void TrainOneIter(const Matrix<BaseFloat> &feats,
const MleDiagGmmOptions &gmm_opts,
int32 iter,
int32 num_threads,
DiagGmm *gmm) {
AccumDiagGmm gmm_acc(*gmm, kGmmAll);
double tot_like = 0.0;
for (int32 t = 0; t < feats.NumRows(); t++)
tot_like += gmm_acc.AccumulateFromDiag(*gmm, feats.Row(t), 1.0);
Vector<BaseFloat> frame_weights(feats.NumRows(), kUndefined);
frame_weights.Set(1.0);
double tot_like;
tot_like = gmm_acc.AccumulateFromDiagMultiThreaded(*gmm, feats, frame_weights,
num_threads);
KALDI_LOG << "Likelihood per frame on iteration " << iter
<< " was " << (tot_like / feats.NumRows()) << " over "
@ -97,17 +100,24 @@ int main(int argc, char *argv[]) {
bool binary = true;
int32 num_gauss = 100;
int32 num_gauss_init = 0;
int32 num_iters = 50;
int32 num_frames = 200000;
int32 srand_seed = 0;
int32 num_threads = 4;
po.Register("binary", &binary, "Write output in binary mode");
po.Register("num-gauss", &num_gauss, "Number of Gaussians in the model");
po.Register("num-gauss-init", &num_gauss_init, "Number of Gaussians in "
"the model initially (if nonzero and less than num_gauss, "
"we'll do mixture splitting)");
po.Register("num-iters", &num_iters, "Number of iterations of training");
po.Register("num-frames", &num_frames, "Number of feature vectors to store in "
"memory and train on (randomly chosen from the input features)");
po.Register("srand", &srand_seed, "Seed for random number generator ");
po.Register("num-threads", &num_threads, "Number of threads used for "
"statistics accumulation");
gmm_opts.Register(&po);
po.Read(argc, argv);
@ -132,7 +142,7 @@ int main(int argc, char *argv[]) {
int64 num_read = 0, dim = 0;
KALDI_LOG << "Reading features (will keep " << num_frames << " frames.)";
for (; !feature_reader.Done(); feature_reader.Next()) {
const Matrix<BaseFloat> &this_feats = feature_reader.Value();
for (int32 t = 0; t < this_feats.NumRows(); t++) {
@ -160,15 +170,36 @@ int main(int argc, char *argv[]) {
KALDI_WARN << "Number of frames read " << num_read << " was less than "
<< "target number " << num_frames << ", using all we read.";
feats.Resize(num_read, dim, kCopyData);
} else {
BaseFloat percent = num_frames * 100.0 / num_read;
KALDI_LOG << "Kept " << num_frames << " out of " << num_read
<< " input frames = " << percent << "%.";
}
DiagGmm gmm(num_gauss, dim);
KALDI_LOG << "Initializing GMM means from random frames";
InitGmmFromRandomFrames(feats, &gmm);
if (num_gauss_init <= 0 || num_gauss_init > num_gauss)
num_gauss_init = num_gauss;
for (int32 iter = 0; iter < num_iters; iter++)
TrainOneIter(feats, gmm_opts, iter, &gmm);
DiagGmm gmm(num_gauss_init, dim);
KALDI_LOG << "Initializing GMM means from random frames to "
<< num_gauss_init << " Gaussians.";
InitGmmFromRandomFrames(feats, &gmm);
// we'll increase the #Gaussians by splitting,
// till halfway through training.
int32 cur_num_gauss = num_gauss_init,
gauss_inc = (num_gauss - num_gauss_init) / (num_iters / 2);
for (int32 iter = 0; iter < num_iters; iter++) {
TrainOneIter(feats, gmm_opts, iter, num_threads, &gmm);
int32 next_num_gauss = std::min(num_gauss, cur_num_gauss + gauss_inc);
if (next_num_gauss > cur_num_gauss) {
KALDI_LOG << "Splitting to " << next_num_gauss << " Gaussians.";
gmm.Split(next_num_gauss, 0.1);
cur_num_gauss = next_num_gauss;
}
}
WriteKaldiObject(gmm, model_wxfilename, binary);
KALDI_LOG << "Wrote model to " << model_wxfilename;