зеркало из https://github.com/mozilla/kaldi.git
Adding the Neural Networks for phoneme-state classification
git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@51 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Родитель
2fc231f3b9
Коммит
0769ce57ee
|
@ -5,4 +5,7 @@ scripts for a sequence of experiments.
|
|||
s1: This setup is experiments with GMM-based systems with various
|
||||
Maximum Likelihood
|
||||
techniques including global and speaker-specific transforms.
|
||||
See a parallel setup in ../wsj/s1
|
||||
See a parallel setup in ../wsj/s1
|
||||
|
||||
s2: This setup is experiment with hybrid MLP system trained by
|
||||
stochastic gradient descent.
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
--use-energy=false # only non-default option.
|
|
@ -0,0 +1,22 @@
|
|||
<Topology>
|
||||
<TopologyEntry>
|
||||
<ForPhones>
|
||||
NONSILENCEPHONES
|
||||
</ForPhones>
|
||||
<State> 0 <PdfClass> 0 <Transition> 0 0.75 <Transition> 1 0.25 </State>
|
||||
<State> 1 <PdfClass> 1 <Transition> 1 0.75 <Transition> 2 0.25 </State>
|
||||
<State> 2 <PdfClass> 2 <Transition> 2 0.75 <Transition> 3 0.25 </State>
|
||||
<State> 3 </State>
|
||||
</TopologyEntry>
|
||||
<TopologyEntry>
|
||||
<ForPhones>
|
||||
SILENCEPHONES
|
||||
</ForPhones>
|
||||
<State> 0 <PdfClass> 0 <Transition> 0 0.25 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 </State>
|
||||
<State> 1 <PdfClass> 1 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 <Transition> 4 0.25 </State>
|
||||
<State> 2 <PdfClass> 2 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 <Transition> 4 0.25 </State>
|
||||
<State> 3 <PdfClass> 3 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 <Transition> 4 0.25 </State>
|
||||
<State> 4 <PdfClass> 4 <Transition> 4 0.25 <Transition> 5 0.75 </State>
|
||||
<State> 5 </State>
|
||||
</TopologyEntry>
|
||||
</Topology>
|
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# usage: make_trans.sh prefix in.flist input.snr out.txt out.scp
|
||||
|
||||
# prefix is first letters of the database "key" (rest are numeric)
|
||||
|
||||
# in.flist is just a list of filenames, probably of .sph files.
|
||||
# input.snr is an snr format file from the RM dataset.
|
||||
# out.txt is the output transcriptions in format "key word1 word\n"
|
||||
# out.scp is the output scp file, which is as in.scp but has the
|
||||
# database-key first on each line.
|
||||
|
||||
# Reads from first argument e.g. $rootdir/rm1_audio1/rm1/doc/al_sents.snr
|
||||
# and second argument train_wav.scp
|
||||
# Writes to standard output trans.txt
|
||||
|
||||
if(@ARGV != 5) {
|
||||
die "usage: make_trans.sh prefix in.flist input.snr out.txt out.scp\n";
|
||||
}
|
||||
($prefix, $in_flist, $input_snr, $out_txt, $out_scp) = @ARGV;
|
||||
|
||||
open(F, "<$input_snr") || die "Opening SNOR file $input_snr";
|
||||
|
||||
while(<F>) {
|
||||
if(m/^;/) { next; }
|
||||
m/(.+) \((.+)\)/ || die "bad line $_";
|
||||
$T{$2} = $1;
|
||||
}
|
||||
|
||||
close(F);
|
||||
open(G, "<$in_flist") || die "Opening file list $in_flist";
|
||||
|
||||
open(O, ">$out_txt") || die "Open output transcription file $out_txt";
|
||||
|
||||
open(P, ">$out_scp") || die "Open output scp file $out_scp";
|
||||
|
||||
while(<G>) {
|
||||
$_ =~ m:/(\w+)/(\w+)\.sph\s+$:i || die "bad scp line $_";
|
||||
$spkname = $1;
|
||||
$uttname = $2;
|
||||
$uttname =~ tr/a-z/A-Z/;
|
||||
defined $T{$uttname} || die "no trans for sent $uttname";
|
||||
$spkname =~ s/_//g; # remove underscore from spk name to make key nicer.
|
||||
$key = $prefix . "_" . $spkname . "_" . $uttname;
|
||||
$key =~ tr/A-Z/a-z/; # Make it all lower case.
|
||||
# to make the numerical and string-sorted orders the same.
|
||||
print O "$key $T{$uttname}\n";
|
||||
print P "$key $_";
|
||||
$n++;
|
||||
}
|
||||
close(O) || die "Closing output.";
|
||||
close(P) || die "Closing output.";
|
||||
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
# This script should be run from the directory where it is located (i.e. data_prep)
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# The input is the 3 CDs from the LDC distribution of Resource Management.
|
||||
# The script's argument is a directory which has three subdirectories:
|
||||
# rm1_audio1 rm1_audio2 rm2_audio
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: ./run.sh /path/to/RM"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
RMROOT=$1
|
||||
if [ ! -d $RMROOT/rm1_audio1 -o ! -d $RMROOT/rm1_audio2 ]; then
|
||||
echo "Error: run.sh requires a directory argument that contains rm1_audio1 and rm1_audio2"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ ! -d $RMROOT/rm2_audio ]; then
|
||||
echo "**Warning: $RMROOT/rm2_audio does not exist; won't create spk2gender.map file correctly***"
|
||||
sleep 1
|
||||
fi
|
||||
|
||||
(
|
||||
find $RMROOT/rm1_audio1/rm1/ind_trn -iname '*.sph';
|
||||
find $RMROOT/rm1_audio2/2_4_2/rm1/ind/dev_aug -iname '*.sph';
|
||||
) | perl -ane ' m:/sa\d.sph:i || m:/sb\d\d.sph:i || print; ' > train_sph.flist
|
||||
|
||||
|
||||
|
||||
# make_trans.pl also creates the utterance id's and the kaldi-format scp file.
|
||||
./make_trans.pl trn train_sph.flist $RMROOT/rm1_audio1/rm1/doc/al_sents.snr train_trans.txt train_sph.scp
|
||||
mv train_trans.txt tmp; sort -k 1 tmp > train_trans.txt
|
||||
mv train_sph.scp tmp; sort -k 1 tmp > train_sph.scp
|
||||
|
||||
sph2pipe=`cd ../../../..; echo $PWD/tools/sph2pipe_v2.5/sph2pipe`
|
||||
if [ ! -f $sph2pipe ]; then
|
||||
echo "Could not find the sph2pipe program at $sph2pipe";
|
||||
exit 1;
|
||||
fi
|
||||
awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < train_sph.scp > train_wav.scp
|
||||
|
||||
cat train_wav.scp | perl -ane 'm/^(\w+_(\w+)\w_\w+) / || die; print "$1 $2\n"' > train.utt2spk
|
||||
cat train.utt2spk | sort -k 2 | ../scripts/utt2spk_to_spk2utt.pl > train.spk2utt
|
||||
|
||||
|
||||
for ntest in 1_mar87 2_oct87 4_feb89 5_oct89 6_feb91 7_sep92; do
|
||||
n=`echo $ntest | cut -d_ -f 1`
|
||||
test=`echo $ntest | cut -d_ -f 2`
|
||||
root=$RMROOT/rm1_audio2/2_4_2
|
||||
for x in `grep -v ';' $root/rm1/doc/tests/$ntest/${n}_indtst.ndx`; do
|
||||
echo "$root/$x ";
|
||||
done > test_${test}_sph.flist
|
||||
done
|
||||
|
||||
# make_trans.pl also creates the utterance id's and the kaldi-format scp file.
|
||||
for test in mar87 oct87 feb89 oct89 feb91 sep92; do
|
||||
./make_trans.pl ${test} test_${test}_sph.flist $RMROOT/rm1_audio1/rm1/doc/al_sents.snr test_${test}_trans.txt test_${test}_sph.scp
|
||||
mv test_${test}_trans.txt tmp; sort -k 1 tmp > test_${test}_trans.txt
|
||||
mv test_${test}_sph.scp tmp; sort -k 1 tmp > test_${test}_sph.scp
|
||||
|
||||
awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < test_${test}_sph.scp > test_${test}_wav.scp
|
||||
|
||||
cat test_${test}_wav.scp | perl -ane 'm/^(\w+_(\w+)\w_\w+) / || die; print "$1 $2\n"' > test_${test}.utt2spk
|
||||
cat test_${test}.utt2spk | sort -k 2 | ../scripts/utt2spk_to_spk2utt.pl > test_${test}.spk2utt
|
||||
done
|
||||
|
||||
cat $RMROOT/rm1_audio2/2_5_1/rm1/doc/al_spkrs.txt \
|
||||
$RMROOT/rm2_audio/3-1.2/rm2/doc/al_spkrs.txt | \
|
||||
perl -ane 'tr/A-Z/a-z/;print;' | grep -v ';' | \
|
||||
awk '{print $1, $2}' > spk2gender.map
|
||||
|
||||
../scripts/make_rm_lm.pl $RMROOT/rm1_audio1/rm1/doc/wp_gram.txt > G.txt
|
||||
|
||||
# Getting lexicon
|
||||
../scripts/make_rm_dict.pl $RMROOT/rm1_audio2/2_4_2/score/src/rdev/pcdsril.txt > lexicon.txt
|
||||
|
||||
echo Succeeded.
|
|
@ -0,0 +1,39 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
fake=false
|
||||
if [ "$1" == "--fake" ]; then
|
||||
fake=true
|
||||
shift
|
||||
fi
|
||||
|
||||
sphdir=$1 # e.g. /mnt/matylda2/data/RM
|
||||
wavdir=$2 # e.g. /mnt/matylda6/jhu09/qpovey/kaldi_rm_wav
|
||||
flistin=$3 # e.g. train_sph.flist, contains sph files in sphdir
|
||||
flistout=$4 # e.g. train_wav.flist, contains wav files in wavdir
|
||||
|
||||
|
||||
if [ $fake == false ]; then
|
||||
for x in `cat $flistin`; do
|
||||
y=`echo $x | sed s:$sphdir:$wavdir: | sed s:.sph:.wav:`;
|
||||
mkdir -p `dirname $y`
|
||||
../../tools/sph2pipe_v2.5/sph2pipe -f wav $x $y || exit 1;
|
||||
done
|
||||
fi
|
||||
|
||||
cat $flistin | sed s:$sphdir:$wavdir: | sed s:.sph:.wav: > $flistout || exit 1;
|
||||
|
|
@ -0,0 +1 @@
|
|||
export PATH=$PATH:../../../src/bin:../../../tools/openfst/bin:../../../src/fstbin/:../../../src/gmmbin/:../../../src/featbin/:../../../src/fgmmbin:../../../src/sgmmbin:../../../src/nnetbin
|
|
@ -0,0 +1,71 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#exit 1 # Don't run this... it's to be run line by line from the shell.
|
||||
|
||||
# This script file cannot be run as-is; some paths in it need to be changed
|
||||
# before you can run it.
|
||||
# Search for /path/to.
|
||||
# It is recommended that you do not invoke this file from the shell, but
|
||||
# run the paths one by one, by hand.
|
||||
|
||||
# the step in data_prep/ will need to be modified for your system.
|
||||
|
||||
# First step is to do data preparation:
|
||||
# This just creates some text files, it is fast.
|
||||
# If not on the BUT system, you would have to change run.sh to reflect
|
||||
# your own paths.
|
||||
#
|
||||
|
||||
#Example arguments to run.sh: /mnt/matylda2/data/RM, /ais/gobi2/speech/RM, /cygdrive/e/data/RM
|
||||
# RM is a directory with subdirectories rm1_audio1, rm1_audio2, rm2_audio
|
||||
cd data_prep
|
||||
#*** You have to change the pathname below.***
|
||||
./run.sh /path/to/RM
|
||||
cd ..
|
||||
|
||||
mkdir -p data
|
||||
( cd data; cp ../data_prep/{train,test*}.{spk2utt,utt2spk} . ; cp ../data_prep/spk2gender.map . )
|
||||
|
||||
# This next step converts the lexicon, grammar, etc., into FST format.
|
||||
steps/prepare_graphs.sh
|
||||
|
||||
|
||||
# Next, make sure that "exp/" is someplace you can write a significant amount of
|
||||
# data to (e.g. make it a link to a file on some reasonably large file system).
|
||||
# If it doesn't exist, the scripts below will make the directory "exp".
|
||||
|
||||
# mfcc should be set to some place to put training mfcc's
|
||||
# where you have space.
|
||||
#e.g.: mfccdir=/mnt/matylda6/jhu09/qpovey/kaldi_rm_mfccb
|
||||
mfccdir=/path/to/mfccdir
|
||||
steps/make_mfcc_train.sh $mfccdir
|
||||
steps/make_mfcc_test.sh $mfccdir
|
||||
|
||||
|
||||
|
||||
# first, we will train monophone GMM system to get training labels
|
||||
steps/train_mono.sh
|
||||
steps/decode_mono.sh &
|
||||
|
||||
|
||||
|
||||
# Now we train the MLP,
|
||||
# it will have CMVN normalized MFCCs as input and phoneme-state posteriors as output
|
||||
steps/train_nnet.sh
|
||||
#steps/decode_nnet.sh #TODO
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Adds some specified number of disambig symbols to a symbol table.
|
||||
# Adds these as #1, #2, etc.
|
||||
# If the --include-zero option is specified, includes an extra one
|
||||
# #0.
|
||||
if(!(@ARGV == 2 || (@ARGV ==3 && $ARGV[0] eq "--include-zero"))) {
|
||||
die "Usage: add_disambig.pl [--include-zero] symtab.txt num_extra > symtab_out.txt ";
|
||||
}
|
||||
|
||||
if(@ARGV == 3) {
|
||||
$include_zero = 1;
|
||||
$ARGV[0] eq "--include-zero" || die "Bad option/first argument $ARGV[0]";
|
||||
shift @ARGV;
|
||||
} else {
|
||||
$include_zero = 0;
|
||||
}
|
||||
|
||||
$input = $ARGV[0];
|
||||
$nsyms = $ARGV[1];
|
||||
|
||||
open(F, "<$input") || die "Opening file $input";
|
||||
|
||||
while(<F>) {
|
||||
@A = split(" ", $_);
|
||||
@A == 2 || die "Bad line $_";
|
||||
$lastsym = $A[1];
|
||||
print;
|
||||
}
|
||||
|
||||
if(!defined($lastsym)){
|
||||
die "Empty symbol file?";
|
||||
}
|
||||
|
||||
if($include_zero) {
|
||||
$lastsym++;
|
||||
print "#0 $lastsym\n";
|
||||
}
|
||||
|
||||
for($n = 1; $n <= $nsyms; $n++) {
|
||||
$y = $n + $lastsym;
|
||||
print "#$n $y\n";
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Adds disambiguation symbols to a lexicon.
|
||||
# Outputs still in the normal lexicon format.
|
||||
# Disambig syms are numbered #1, #2, #3, etc. (#0
|
||||
# reserved for symbol in grammar).
|
||||
# Outputs the number of disambig syms to the standard output.
|
||||
|
||||
if(@ARGV != 2) {
|
||||
die "Usage: add_lex_disambig.pl lexicon.txt lexicon_disambig.txt "
|
||||
}
|
||||
|
||||
|
||||
$lexfn = shift @ARGV;
|
||||
$lexoutfn = shift @ARGV;
|
||||
|
||||
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
|
||||
|
||||
# (1) Read in the lexicon.
|
||||
@L = ( );
|
||||
while(<L>) {
|
||||
@A = split(" ", $_);
|
||||
push @L, join(" ", @A);
|
||||
}
|
||||
|
||||
# (2) Work out the count of each phone-sequence in the
|
||||
# lexicon.
|
||||
|
||||
foreach $l (@L) {
|
||||
@A = split(" ", $l);
|
||||
shift @A; # Remove word.
|
||||
$count{join(" ",@A)}++;
|
||||
}
|
||||
|
||||
# (3) For each left sub-sequence of each phone-sequence, note down
|
||||
# that exists (for identifying prefixes of longer strings).
|
||||
|
||||
foreach $l (@L) {
|
||||
@A = split(" ", $l);
|
||||
shift @A; # Remove word.
|
||||
while(@A > 0) {
|
||||
pop @A; # Remove last phone
|
||||
$issubseq{join(" ",@A)} = 1;
|
||||
}
|
||||
}
|
||||
|
||||
# (4) For each entry in the lexicon:
|
||||
# if the phone sequence is unique and is not a
|
||||
# prefix of another word, no diambig symbol.
|
||||
# Else output #1, or #2, #3, ... if the same phone-seq
|
||||
# has already been assigned a disambig symbol.
|
||||
|
||||
|
||||
open(O, ">$lexoutfn") || die "Opening lexicon file $lexoutfn for writing.\n";
|
||||
|
||||
$max_disambig = 0;
|
||||
foreach $l (@L) {
|
||||
@A = split(" ", $l);
|
||||
$word = shift @A;
|
||||
$phnseq = join(" ",@A);
|
||||
if(!defined $issubseq{$phnseq}
|
||||
&& $count{$phnseq}==1) {
|
||||
; # Do nothing.
|
||||
} else {
|
||||
if($phnseq eq "") { # need disambig symbols for the empty string
|
||||
# that are not use anywhere else.
|
||||
$max_disambig++;
|
||||
$reserved{$max_disambig} = 1;
|
||||
$phnseq = "#$max_disambig";
|
||||
} else {
|
||||
$curnumber = $disambig_of{$phnseq};
|
||||
if(!defined{$curnumber}) { $curnumber = 0; }
|
||||
$curnumber++; # now 1 or 2, ...
|
||||
while(defined $reserved{$curnumber} ) { $curnumber++; } # skip over reserved symbols
|
||||
if($curnumber > $max_disambig) {
|
||||
$max_disambig = $curnumber;
|
||||
}
|
||||
$disambig_of{$phnseq} = $curnumber;
|
||||
$phnseq = $phnseq . " #" . $curnumber;
|
||||
}
|
||||
}
|
||||
print O "$word\t$phnseq\n";
|
||||
}
|
||||
|
||||
print $max_disambig . "\n";
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/perl -w
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This script takes a list of utterance-ids and filters an scp
|
||||
# file (or any file whose first field is an utterance id), printing
|
||||
# out only those lines whose first field is in id_list.
|
||||
|
||||
if(@ARGV < 1 || @ARGV > 2) {
|
||||
die "Usage: filter_scp.pl id_list [in.scp] > out.scp ";
|
||||
}
|
||||
|
||||
$idlist = shift @ARGV;
|
||||
open(F, "<$idlist") || die "Could not open id-list file $idlist";
|
||||
while(<F>) {
|
||||
@A = split;
|
||||
@A>=1 || die "Invalid id-list file line $_";
|
||||
$seen{$A[0]} = 1;
|
||||
}
|
||||
|
||||
while(<>) {
|
||||
@A = split;
|
||||
@A > 0 || die "Invalid scp file line $_";
|
||||
if($seen{$A[0]}) {
|
||||
print $_;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
#!/usr/bin/python -u
|
||||
|
||||
# ./gen_hamm_dct.py
|
||||
# script generateing NN initialization for training with TNet
|
||||
#
|
||||
# author: Karel Vesely
|
||||
#
|
||||
|
||||
import math, random
|
||||
import sys
|
||||
|
||||
|
||||
from optparse import OptionParser
|
||||
|
||||
parser = OptionParser()
|
||||
parser.add_option('--dim', dest='dim', help='d1:d2:d3 layer dimensions in the network')
|
||||
parser.add_option('--gauss', dest='gauss', help='use gaussian noise for weights', action='store_true', default=False)
|
||||
parser.add_option('--negbias', dest='negbias', help='use uniform [-4.1,-3.9] for bias (defaultall 0.0)', action='store_true', default=False)
|
||||
parser.add_option('--inputscale', dest='inputscale', help='scale the weights by 3/sqrt(Ninputs)', action='store_true', default=False)
|
||||
parser.add_option('--linBNdim', dest='linBNdim', help='dim of linear bottleneck (sigmoids will be omitted, bias will be zero)',default=0)
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
if(options.dim == None):
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
dimStrL = options.dim.split(':')
|
||||
|
||||
dimL = []
|
||||
for i in range(len(dimStrL)):
|
||||
dimL.append(int(dimStrL[i]))
|
||||
|
||||
|
||||
#print dimL,'linBN',options.linBNdim
|
||||
|
||||
for layer in range(len(dimL)-1):
|
||||
print '<biasedlinearity>', dimL[layer+1], dimL[layer]
|
||||
#weight matrix
|
||||
print '['
|
||||
for row in range(dimL[layer+1]):
|
||||
for col in range(dimL[layer]):
|
||||
if(options.gauss):
|
||||
if(options.inputscale):
|
||||
print 3/math.sqrt(dimL[layer])*random.gauss(0.0,1.0),
|
||||
else:
|
||||
print 0.1*random.gauss(0.0,1.0),
|
||||
else:
|
||||
if(options.inputscale):
|
||||
print (random.random()-0.5)*2*3/math.sqrt(dimL[layer]),
|
||||
else:
|
||||
print random.random()/5.0-0.1,
|
||||
print #newline for each row
|
||||
print ']'
|
||||
#bias vector
|
||||
print '[',
|
||||
for idx in range(dimL[layer+1]):
|
||||
if(int(options.linBNdim) == dimL[layer+1]):
|
||||
print '0.0',
|
||||
elif(options.negbias):
|
||||
print random.random()/5.0-4.1,
|
||||
else:
|
||||
print '0.0',
|
||||
print ']'
|
||||
|
||||
if(int(options.linBNdim) != dimL[layer+1]):
|
||||
if(layer == len(dimL)-2):
|
||||
print '<softmax>', dimL[layer+1], dimL[layer+1]
|
||||
else:
|
||||
print '<sigmoid>', dimL[layer+1], dimL[layer+1]
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
$ignore_noninteger = 0;
|
||||
$ignore_first_field = 0;
|
||||
for($x = 0; $x < 2; $x++) {
|
||||
if($ARGV[0] eq "--ignore-noninteger") { $ignore_oov = 1; shift @ARGV; }
|
||||
if($ARGV[0] eq "--ignore-first-field") { $ignore_first_field = 1; shift @ARGV; }
|
||||
}
|
||||
|
||||
$symtab = shift @ARGV;
|
||||
if(!defined $symtab) {
|
||||
die "Usage: sym2int.pl symtab [input transcriptions] > output transcriptions\n";
|
||||
}
|
||||
open(F, "<$symtab") || die "Error opening symbol table file $symtab";
|
||||
while(<F>) {
|
||||
@A = split(" ", $_);
|
||||
@A == 2 || die "bad line in symbol table file: $_";
|
||||
$int2sym{$A[1]} = $A[0];
|
||||
}
|
||||
|
||||
$error = 0;
|
||||
while(<>) {
|
||||
@A = split(" ", $_);
|
||||
if(@A == 0) {
|
||||
die "Empty line in transcriptions input.";
|
||||
}
|
||||
if($ignore_first_field) {
|
||||
$key = shift @A;
|
||||
print $key . " ";
|
||||
}
|
||||
foreach $a (@A) {
|
||||
if($a !~ m:^\d+$:) { # not all digits..
|
||||
if($ignore_noninteger) {
|
||||
print $a . " ";
|
||||
next;
|
||||
} else {
|
||||
if($a eq $A[0]) {
|
||||
die "int2sym.pl: found noninteger token $a (try --ignore-first-field)\n";
|
||||
} else {
|
||||
die "int2sym.pl: found noninteger token $a (try --ignore-noninteger if valid input)\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
$s = $int2sym{$a};
|
||||
if(!defined ($s)) {
|
||||
die "int2sym.pl: integer $a not in symbol table $symtab.";
|
||||
}
|
||||
print $s . " ";
|
||||
}
|
||||
print "\n";
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Usage: is_sorted.sh [script-file]
|
||||
# This script returns 0 (success) if the script file argument [or standard input]
|
||||
# is sorted and 1 otherwise.
|
||||
|
||||
export LC_ALL=C
|
||||
|
||||
if [ $# == 0 ]; then
|
||||
scp=-
|
||||
fi
|
||||
if [ $# == 1 ]; then
|
||||
scp=$1
|
||||
fi
|
||||
if [ $# -gt 1 -o "$1" == "--help" -o "$1" == "-h" ]; then
|
||||
echo "Usage: is_sorted.sh [script-file]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cat $scp > /tmp/tmp1.$$
|
||||
sort /tmp/tmp1.$$ > /tmp/tmp2.$$
|
||||
cmp /tmp/tmp1.$$ /tmp/tmp2.$$ >/dev/null
|
||||
ret=$?
|
||||
rm /tmp/tmp1.$$ /tmp/tmp2.$$
|
||||
if [ $ret == 0 ]; then
|
||||
exit 0;
|
||||
else
|
||||
echo "is_sorted.sh: script file $scp is not sorted";
|
||||
exit 1;
|
||||
fi
|
|
@ -0,0 +1,112 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# makes lexicon FST (no pron-probs involved).
|
||||
|
||||
if(@ARGV != 1 && @ARGV != 3) {
|
||||
die "Usage: make_lexicon_fst.pl lexicon.txt [silprob silphone] > lexiconfst.txt"
|
||||
}
|
||||
|
||||
$lexfn = shift @ARGV;
|
||||
if(@ARGV == 0) {
|
||||
$silprob = 0.0;
|
||||
} else {
|
||||
($silprob,$silphone) = @ARGV;
|
||||
}
|
||||
if($silprob != 0.0) {
|
||||
$silprob < 1.0 || die "Sil prob cannot be >= 1.0";
|
||||
$silcost = -log($silprob);
|
||||
$nosilcost = -log(1.0 - $silprob);
|
||||
}
|
||||
|
||||
|
||||
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
|
||||
|
||||
|
||||
|
||||
if( $silprob == 0.0 ) { # No optional silences: just have one (loop+final) state which is numbered zero.
|
||||
$loopstate = 0;
|
||||
$nexststate = 1; # next unallocated state.
|
||||
while(<L>) {
|
||||
@A = split(" ", $_);
|
||||
$w = shift @A;
|
||||
if(@A == 0) { # For empty words (<s> and </s>) insert no optional
|
||||
# silence (not needed as adjacent words supply it)....
|
||||
# actually we only hit this case for the lexicon without disambig
|
||||
# symbols but doesn't ever matter as training transcripts don't have <s> or </s>.
|
||||
print "$loopstate\t$loopstate\t<eps>\t$w\n";
|
||||
} else {
|
||||
$s = $loopstate;
|
||||
$word_or_eps = $w;
|
||||
while (@A > 0) {
|
||||
$p = shift @A;
|
||||
if(@A > 0) {
|
||||
$ns = $nextstate++;
|
||||
} else {
|
||||
$ns = $loopstate;
|
||||
}
|
||||
print "$s\t$ns\t$p\t$word_or_eps\n";
|
||||
$word_or_eps = "<eps>";
|
||||
$s = $ns;
|
||||
}
|
||||
}
|
||||
}
|
||||
print "$loopstate\t0\n"; # final-cost.
|
||||
} else { # have silence probs.
|
||||
$startstate = 0;
|
||||
$loopstate = 1;
|
||||
$silstate = 2; # state from where we go to loopstate after emitting silence.
|
||||
$nextstate = 3;
|
||||
print "$startstate\t$loopstate\t<eps>\t<eps>\t$nosilcost\n"; # no silence.
|
||||
print "$startstate\t$loopstate\t$silphone\t<eps>\t$silcost\n"; # silence.
|
||||
print "$silstate\t$loopstate\t$silphone\t<eps>\n"; # no cost.
|
||||
while(<L>) {
|
||||
@A = split(" ", $_);
|
||||
$w = shift @A;
|
||||
if(@A == 0) { # For empty words (<s> and </s>) insert no optional
|
||||
# silence (not needed as adjacent words supply it)....
|
||||
# actually we only hit this case for the lexicon without disambig
|
||||
# symbols but doesn't ever matter as training transcripts don't have <s> or </s>.
|
||||
print "$loopstate\t$loopstate\t<eps>\t$w\n";
|
||||
} else {
|
||||
$is_silence_word = (@A == 1 && $A[0] eq $silphone); # boolean.
|
||||
$s = $loopstate;
|
||||
$word_or_eps = $w;
|
||||
while (@A > 0) {
|
||||
$p = shift @A;
|
||||
if(@A > 0) {
|
||||
$ns = $nextstate++;
|
||||
print "$s\t$ns\t$p\t$word_or_eps\n";
|
||||
$word_or_eps = "<eps>";
|
||||
$s = $ns;
|
||||
} else {
|
||||
if(! $is_silence_word) {
|
||||
# This is non-deterministic but relatively compact,
|
||||
# and avoids epsilons.
|
||||
print "$s\t$loopstate\t$p\t$word_or_eps\t$nosilcost\n";
|
||||
print "$s\t$silstate\t$p\t$word_or_eps\t$silcost\n";
|
||||
} else {
|
||||
# no point putting opt-sil after silence word.
|
||||
print "$s\t$loopstate\t$p\t$word_or_eps\n";
|
||||
}
|
||||
$word_or_eps = "<eps>";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
print "$loopstate\t0\n"; # final-cost.
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# make_phones_symtab.pl < lexicon.txt > phones.txt
|
||||
|
||||
|
||||
while(<>) {
|
||||
@A = split(" ", $_);
|
||||
for ($i=2; $i<@A; $i++) {
|
||||
$P{$A[$i]} = 1; # seen it.
|
||||
}
|
||||
}
|
||||
|
||||
print "<eps>\t0\n";
|
||||
$n = 1;
|
||||
foreach $p (sort keys %P) {
|
||||
if($p ne "<eps>") {
|
||||
print "$p\t$n\n";
|
||||
$n++;
|
||||
}
|
||||
}
|
||||
|
||||
print "sil\t$n\n";
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Yanmin Qian Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file takes as input the file pcdsril.txt that comes with the RM
|
||||
# distribution, and creates the dictionary used in RM training.
|
||||
|
||||
# make_rm_dct.pl pcdsril.txt > dct.txt
|
||||
|
||||
if (@ARGV != 1) {
|
||||
die "usage: make_rm_dct.pl pcdsril.txt > dct.txt\n";
|
||||
}
|
||||
unless (open(IN_FILE, "@ARGV[0]")) {
|
||||
die ("can't open @ARGV[0]");
|
||||
}
|
||||
|
||||
while ($line = <IN_FILE>)
|
||||
{
|
||||
chop($line);
|
||||
if (($line =~ /^[a-z]/))
|
||||
{
|
||||
$line =~ s/\+1//g;
|
||||
@LineArray = split(/\s+/,$line);
|
||||
@LineArray[0] = uc(@LineArray[0]);
|
||||
|
||||
printf "%-16s", @LineArray[0];
|
||||
for ($i = 1; $i < @LineArray; $i ++)
|
||||
{
|
||||
if (@LineArray[$i] eq 'q')
|
||||
{}
|
||||
elsif (@LineArray[$i] eq 'zh')
|
||||
{
|
||||
printf "sh ";
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'eng')
|
||||
{
|
||||
printf "ng ";
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'hv')
|
||||
{
|
||||
printf "hh ";
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'em')
|
||||
{
|
||||
printf "m ";
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'axr')
|
||||
{
|
||||
printf "er ";
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'tcl')
|
||||
{
|
||||
if (@LineArray[$i+1] ne 't')
|
||||
{
|
||||
printf "td ";
|
||||
}
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'dcl')
|
||||
{
|
||||
if (@LineArray[$i+1] ne 'd')
|
||||
{
|
||||
printf "dd ";
|
||||
}
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'kcl')
|
||||
{
|
||||
if (@LineArray[$i+1] ne 'k')
|
||||
{
|
||||
printf "kd ";
|
||||
}
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'pcl')
|
||||
{
|
||||
if (@LineArray[$i+1] ne 'p')
|
||||
{
|
||||
printf "pd ";
|
||||
}
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'bcl')
|
||||
{
|
||||
if (@LineArray[$i+1] ne 'b')
|
||||
{
|
||||
printf "b ";
|
||||
}
|
||||
}
|
||||
elsif (@LineArray[$i] eq 'gcl')
|
||||
{
|
||||
if (@LineArray[$i+1] ne 'g')
|
||||
{
|
||||
printf "g ";
|
||||
}
|
||||
}
|
||||
elsif (@LineArray[$i] eq 't')
|
||||
{
|
||||
if (@LineArray[$i+1] ne 's')
|
||||
{
|
||||
printf "@LineArray[$i] ";
|
||||
}
|
||||
else
|
||||
{
|
||||
printf "ts ";
|
||||
$i++;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf "@LineArray[$i] ";
|
||||
}
|
||||
}
|
||||
printf "\n";
|
||||
}
|
||||
}
|
||||
|
||||
printf "!SIL sil\n";
|
||||
|
||||
close(IN_FILE);
|
||||
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
#!/usr/bin/perl
|
||||
|
||||
# Copyright 2010-2011 Yanmin Qian Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file takes as input the file wp_gram.txt that comes with the RM
|
||||
# distribution, and creates the language model as an acceptor in FST form.
|
||||
|
||||
# make_rm_lm.pl wp_gram.txt > G.txt
|
||||
|
||||
if (@ARGV != 1) {
|
||||
print "usage: make_rm_lm.pl wp_gram.txt > G.txt\n";
|
||||
exit(0);
|
||||
}
|
||||
unless (open(IN_FILE, "@ARGV[0]")) {
|
||||
die ("can't open @ARGV[0]");
|
||||
}
|
||||
|
||||
|
||||
$flag = 0;
|
||||
$count_wrd = 0;
|
||||
$cnt_ends = 0;
|
||||
$init = "";
|
||||
|
||||
while ($line = <IN_FILE>)
|
||||
{
|
||||
chop($line);
|
||||
|
||||
$line =~ s/ //g;
|
||||
|
||||
if(($line =~ /^>/))
|
||||
{
|
||||
if($flag == 0)
|
||||
{
|
||||
$flag = 1;
|
||||
}
|
||||
$line =~ s/>//g;
|
||||
$hashcnt{$init} = $i;
|
||||
$init = $line;
|
||||
$i = 0;
|
||||
$count_wrd++;
|
||||
@LineArray[$count_wrd - 1] = $init;
|
||||
$hashwrd{$init} = 0;
|
||||
}
|
||||
elsif($flag != 0)
|
||||
{
|
||||
|
||||
$hash{$init}[$i] = $line;
|
||||
$i++;
|
||||
if($line =~ /SENTENCE-END/)
|
||||
{
|
||||
$cnt_ends++;
|
||||
}
|
||||
}
|
||||
else
|
||||
{}
|
||||
}
|
||||
|
||||
$hashcnt{$init} = $i;
|
||||
|
||||
$num = 0;
|
||||
$weight = 0;
|
||||
$init_wrd = "SENTENCE-END";
|
||||
$hashwrd{$init_wrd} = @LineArray;
|
||||
for($i = 0; $i < $hashcnt{$init_wrd}; $i++)
|
||||
{
|
||||
$weight = -log(1/$hashcnt{$init_wrd});
|
||||
$hashwrd{$hash{$init_wrd}[$i]} = $i + 1;
|
||||
print "0 $hashwrd{$hash{$init_wrd}[$i]} $hash{$init_wrd}[$i] $hash{$init_wrd}[$i] $weight\n";
|
||||
}
|
||||
$num = $i;
|
||||
|
||||
for($i = 0; $i < @LineArray; $i++)
|
||||
{
|
||||
if(@LineArray[$i] eq 'SENTENCE-END')
|
||||
{}
|
||||
else
|
||||
{
|
||||
if($hashwrd{@LineArray[$i]} == 0)
|
||||
{
|
||||
$num++;
|
||||
$hashwrd{@LineArray[$i]} = $num;
|
||||
}
|
||||
for($j = 0; $j < $hashcnt{@LineArray[$i]}; $j++)
|
||||
{
|
||||
$weight = -log(1/$hashcnt{@LineArray[$i]});
|
||||
if($hashwrd{$hash{@LineArray[$i]}[$j]} == 0)
|
||||
{
|
||||
$num++;
|
||||
$hashwrd{$hash{@LineArray[$i]}[$j]} = $num;
|
||||
}
|
||||
if($hash{@LineArray[$i]}[$j] eq 'SENTENCE-END')
|
||||
{
|
||||
print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} <eps> <eps> $weight\n"
|
||||
}
|
||||
else
|
||||
{
|
||||
print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} $hash{@LineArray[$i]}[$j] $hash{@LineArray[$i]}[$j] $weight\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
print "$hashwrd{$init_wrd} 0\n";
|
||||
close(IN_FILE);
|
||||
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# Written by Dan Povey 9/21/2010. Apache 2.0 License.
|
||||
|
||||
# This version of make_roots.pl is specialized for RM.
|
||||
|
||||
# This script creates the file roots.txt which is an input to train-tree.cc. It
|
||||
# specifies how the trees are built. The input file phone-sets.txt is a partial
|
||||
# version of roots.txt in which phones are represented by their spelled form, not
|
||||
# their symbol id's. E.g. at input, phone-sets.txt might contain;
|
||||
# shared not-split sil
|
||||
# Any phones not specified in phone-sets.txt but present in phones.txt will
|
||||
# be given a default treatment. If the --separate option is given, we create
|
||||
# a separate tree root for each of them, otherwise they are all lumped in one set.
|
||||
# The arguments shared|not-shared and split|not-split are needed if any
|
||||
# phones are not specified in phone-sets.txt. What they mean is as follows:
|
||||
# if shared=="shared" then we share the tree-root between different HMM-positions
|
||||
# (0,1,2). If split=="split" then we actually do decision tree splitting on
|
||||
# that root, otherwise we forbid decision-tree splitting. (The main reason we might
|
||||
# set this to false is for silence when
|
||||
# we want to ensure that the HMM-positions will remain with a single PDF id.
|
||||
|
||||
|
||||
$separate = 0;
|
||||
if($ARGV[0] eq "--separate") {
|
||||
$separate = 1;
|
||||
shift @ARGV;
|
||||
}
|
||||
|
||||
if(@ARGV != 4) {
|
||||
die "Usage: make_roots.pl [--separate] phones.txt silence-phone-list[integer,colon-separated] shared|not-shared split|not-split > roots.txt\n";
|
||||
}
|
||||
|
||||
|
||||
($phonesfile, $silphones, $shared, $split) = @ARGV;
|
||||
if($shared ne "shared" && $shared ne "not-shared") {
|
||||
die "Third argument must be \"shared\" or \"not-shared\"\n";
|
||||
}
|
||||
if($split ne "split" && $split ne "not-split") {
|
||||
die "Third argument must be \"split\" or \"not-split\"\n";
|
||||
}
|
||||
|
||||
|
||||
|
||||
open(F, "<$phonesfile") || die "Opening file $phonesfile";
|
||||
|
||||
while(<F>) {
|
||||
@A = split(" ", $_);
|
||||
if(@A != 2) {
|
||||
die "Bad line in phones symbol file: ".$_;
|
||||
}
|
||||
if($A[1] != 0) {
|
||||
$symbol2id{$A[0]} = $A[1];
|
||||
$id2symbol{$A[1]} = $A[0];
|
||||
}
|
||||
}
|
||||
|
||||
if($silphones == ""){
|
||||
die "Empty silence phone list in make_roots.pl";
|
||||
}
|
||||
foreach $silphoneid (split(":", $silphones)) {
|
||||
defined $id2symbol{$silphoneid} || die "No such silence phone id $silphoneid";
|
||||
# Give each silence phone its own separate pdfs in each state, but
|
||||
# no sharing (in this recipe; WSJ is different.. in this recipe there
|
||||
#is only one silence phone anyway.)
|
||||
$issil{$silphoneid} = 1;
|
||||
print "not-shared not-split $silphoneid\n";
|
||||
}
|
||||
|
||||
$idlist = "";
|
||||
$remaining_phones = "";
|
||||
|
||||
if($separate){
|
||||
foreach $a (keys %id2symbol) {
|
||||
if(!defined $issil{$a}) {
|
||||
print "$shared $split $a\n";
|
||||
}
|
||||
}
|
||||
} else {
|
||||
print "$shared $split ";
|
||||
foreach $a (keys %id2symbol) {
|
||||
if(!defined $issil{$a}) {
|
||||
print "$a ";
|
||||
}
|
||||
}
|
||||
print "\n";
|
||||
}
|
|
@ -0,0 +1,39 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# make_words_symtab.pl < G.txt > words.txt
|
||||
|
||||
|
||||
|
||||
|
||||
while(<>) {
|
||||
@A = split(" ", $_);
|
||||
if(@A >= 3) {
|
||||
$W{$A[2]} = 1;
|
||||
}
|
||||
}
|
||||
|
||||
print "<eps>\t0\n";
|
||||
$n = 1;
|
||||
foreach $w (sort keys %W) {
|
||||
if($w ne "<eps>") {
|
||||
print "$w\t$n\n";
|
||||
$n++;
|
||||
}
|
||||
}
|
||||
|
||||
print "!SIL\t$n\n";
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
reorder=true # Dan-style, make false for Mirko+Lukas's decoder.
|
||||
|
||||
for x in 1 2 3; do
|
||||
if [ $1 == "--mono" ]; then
|
||||
monophone_opts="--context-size=1 --central-position=0"
|
||||
shift;
|
||||
fi
|
||||
|
||||
if [ $1 == "--noreorder" ]; then
|
||||
reorder=false # we set this for the Kaldi decoder.
|
||||
shift;
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: scripts/mkgraph.sh <tree> <model> <graphdir>"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
|
||||
tree=$1
|
||||
model=$2
|
||||
dir=$3
|
||||
|
||||
mkdir -p $dir
|
||||
|
||||
tscale=1.0
|
||||
loopscale=0.1
|
||||
|
||||
fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminizestar --use-log=true | \
|
||||
fstminimizeencoded > $dir/LG.fst
|
||||
|
||||
fstisstochastic $dir/LG.fst || echo "warning: LG not stochastic."
|
||||
|
||||
echo "Example string from LG.fst: "
|
||||
echo
|
||||
fstrandgen --select=log_prob $dir/LG.fst | fstprint --isymbols=data/phones_disambig.txt --osymbols=data/words.txt -
|
||||
|
||||
grep '#' data/phones_disambig.txt | awk '{print $2}' > $dir/disambig_phones.list
|
||||
|
||||
fstcomposecontext $monophone_opts \
|
||||
--read-disambig-syms=$dir/disambig_phones.list \
|
||||
--write-disambig-syms=$dir/disambig_ilabels.list \
|
||||
$dir/ilabels < $dir/LG.fst >$dir/CLG.fst
|
||||
|
||||
# for debugging:
|
||||
fstmakecontextsyms data/phones.txt $dir/ilabels > $dir/context_syms.txt
|
||||
echo "Example string from CLG.fst: "
|
||||
echo
|
||||
fstrandgen --select=log_prob $dir/CLG.fst | fstprint --isymbols=$dir/context_syms.txt --osymbols=data/words.txt -
|
||||
|
||||
fstisstochastic $dir/CLG.fst || echo "warning: CLG not stochastic."
|
||||
|
||||
make-ilabel-transducer --write-disambig-syms=$dir/disambig_ilabels_remapped.list $dir/ilabels $tree $model $dir/ilabels.remapped > $dir/ilabel_map.fst
|
||||
|
||||
# Reduce size of CLG by remapping symbols...
|
||||
fsttablecompose $dir/ilabel_map.fst $dir/CLG.fst | fstdeterminizestar --use-log=true \
|
||||
| fstminimizeencoded > $dir/CLG2.fst
|
||||
|
||||
|
||||
cat $dir/CLG2.fst | fstisstochastic || echo "warning: CLG2 is not stochastic."
|
||||
|
||||
make-h-transducer --disambig-syms-out=$dir/disambig_tstate.list \
|
||||
--transition-scale=$tscale $dir/ilabels.remapped $tree $model > $dir/Ha.fst
|
||||
|
||||
|
||||
fsttablecompose $dir/Ha.fst $dir/CLG2.fst | fstdeterminizestar --use-log=true \
|
||||
| fstrmsymbols $dir/disambig_tstate.list | fstrmepslocal | fstminimizeencoded > $dir/HCLGa.fst
|
||||
|
||||
fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic"
|
||||
|
||||
add-self-loops --self-loop-scale=$loopscale --reorder=$reorder $model < $dir/HCLGa.fst > $dir/HCLG.fst
|
||||
|
||||
if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then
|
||||
# No point doing this test if transition-scale not 1, as it is bound to fail.
|
||||
fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic."
|
||||
fi
|
||||
|
||||
fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic."
|
||||
|
||||
|
||||
#The next five lines are debug.
|
||||
# The last two lines of this block print out some alignment info.
|
||||
fstrandgen --select=log_prob $dir/HCLG.fst | fstprint --osymbols=data/words.txt > $dir/rand.txt
|
||||
cat $dir/rand.txt | awk 'BEGIN{printf("0 ");} {if(NF>=3 && $3 != 0){ printf ("%d ",$3); }} END {print ""; }' > $dir/rand_align.txt
|
||||
|
||||
show-alignments data/phones.txt $model ark:$dir/rand_align.txt
|
||||
cat $dir/rand.txt | awk ' {if(NF>=4 && $4 != "<eps>"){ printf ("%s ",$4); }} END {print ""; }'
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This version of mkgraph.sh creates the C fst explicitly.
|
||||
|
||||
reorder=true # Dan-style, make false for Mirko+Lukas's decoder.
|
||||
|
||||
for x in 1 2 3; do
|
||||
if [ $1 == "--mono" ]; then
|
||||
monophone_opts="--context-size=1 --central-position=0"
|
||||
shift;
|
||||
fi
|
||||
|
||||
if [ $1 == "--noreorder" ]; then
|
||||
reorder=false # we set this for the Kaldi decoder.
|
||||
shift;
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: scripts/mkgraph.sh <tree> <model> <graphdir>"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
|
||||
|
||||
tree=$1
|
||||
model=$2
|
||||
dir=$3
|
||||
|
||||
mkdir -p $dir
|
||||
|
||||
tscale=1.0
|
||||
loopscale=0.1
|
||||
|
||||
fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminizestar --use-log=true | \
|
||||
fstminimizeencoded > $dir/LG.fst
|
||||
|
||||
fstisstochastic $dir/LG.fst || echo "warning: LG not stochastic."
|
||||
|
||||
echo "Example string from LG.fst: "
|
||||
echo
|
||||
fstrandgen --select=log_prob $dir/LG.fst | fstprint --isymbols=data/phones_disambig.txt --osymbols=data/words.txt -
|
||||
|
||||
grep '#' data/phones_disambig.txt | awk '{print $2}' > $dir/disambig_phones.list
|
||||
subseq_sym=`tail -1 data/phones_disambig.txt | awk '{print $2+1;}'`
|
||||
cp data/phones_disambig.txt $dir/phones_disambig_subseq.txt
|
||||
echo '$' $subseq_sym >> $dir/phones_disambig_subseq.txt
|
||||
|
||||
fstmakecontextfst --read-disambig-syms=$dir/disambig_phones.list \
|
||||
--write-disambig-syms=$dir/disambig_ilabels.list data/phones.txt $subseq_sym \
|
||||
$dir/ilabels | fstarcsort --sort_type=olabel > $dir/C.fst
|
||||
|
||||
fstaddsubsequentialloop $subseq_sym $dir/LG.fst | \
|
||||
fsttablecompose $dir/C.fst - > $dir/CLG.fst
|
||||
|
||||
|
||||
# for debugging:
|
||||
fstmakecontextsyms data/phones.txt $dir/ilabels > $dir/context_syms.txt
|
||||
echo "Example string from CLG.fst: "
|
||||
echo
|
||||
fstrandgen --select=log_prob $dir/CLG.fst | fstprint --isymbols=$dir/context_syms.txt --osymbols=data/words.txt -
|
||||
|
||||
fstisstochastic $dir/CLG.fst || echo "warning: CLG not stochastic."
|
||||
|
||||
make-ilabel-transducer --write-disambig-syms=$dir/disambig_ilabels_remapped.list $dir/ilabels $tree $model $dir/ilabels.remapped > $dir/ilabel_map.fst
|
||||
|
||||
# Reduce size of CLG by remapping symbols...
|
||||
fstcompose $dir/ilabel_map.fst $dir/CLG.fst | fstdeterminizestar --use-log=true \
|
||||
| fstminimizeencoded > $dir/CLG2.fst
|
||||
|
||||
|
||||
cat $dir/CLG2.fst | fstisstochastic || echo "warning: CLG2 is not stochastic."
|
||||
|
||||
make-h-transducer --disambig-syms-out=$dir/disambig_tstate.list \
|
||||
--transition-scale=$tscale $dir/ilabels.remapped $tree $model > $dir/Ha.fst
|
||||
|
||||
|
||||
fsttablecompose $dir/Ha.fst $dir/CLG2.fst | fstdeterminizestar --use-log=true \
|
||||
| fstrmsymbols $dir/disambig_tstate.list | fstrmepslocal | fstminimizeencoded > $dir/HCLGa.fst
|
||||
|
||||
fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic"
|
||||
|
||||
add-self-loops --self-loop-scale=$loopscale --reorder=$reorder $model < $dir/HCLGa.fst > $dir/HCLG.fst
|
||||
|
||||
if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then
|
||||
# No point doing this test if transition-scale not 1, as it is bound to fail.
|
||||
fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic."
|
||||
fi
|
||||
|
||||
fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic."
|
||||
|
||||
|
||||
#The next five lines are debug.
|
||||
# The last two lines of this block print out some alignment info.
|
||||
fstrandgen --select=log_prob $dir/HCLG.fst | fstprint --osymbols=data/words.txt > $dir/rand.txt
|
||||
cat $dir/rand.txt | awk 'BEGIN{printf("0 ");} {if(NF>=3 && $3 != 0){ printf ("%d ",$3); }} END {print ""; }' > $dir/rand_align.txt
|
||||
show-alignments data/phones.txt $model ark:$dir/rand_align.txt
|
||||
cat $dir/rand.txt | awk ' {if(NF>=4 && $4 != "<eps>"){ printf ("%s ",$4); }} END {print ""; }'
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This script is part of a diagnostic step when using exponential transforms.
|
||||
|
||||
$map=$ARGV[0]; open(M,"<$map")||die "opening map file $map";
|
||||
while(<M>){ @A=split(" ",$_); $map{$A[0]} = $A[1]; }
|
||||
while(<STDIN>){
|
||||
($spk,$warp)=split(" ",$_);
|
||||
$class = int($class/2);
|
||||
defined $map{$spk} || die "No gender info for speaker $spk";
|
||||
$warps{$map{$spk}} = $warps{$map{$spk}} . "$warp ";
|
||||
}
|
||||
@K = sort keys %warps;
|
||||
@K==2||die "wrong number of keys [empty warps file?]";
|
||||
foreach $k ( @K ) {
|
||||
$s = join(" ", sort { $a <=> $b } ( split(" ", $warps{$k}) )) ;
|
||||
print "$k = [ $s ];\n";
|
||||
}
|
||||
# f,m may be reversed below; doesnt matter.
|
||||
foreach $w ( split(" ", $warps{$K[0]}) ) {
|
||||
$nf += 1; $sumf += $w; $sumf2 += $w*$w;
|
||||
}
|
||||
foreach $w ( split(" ", $warps{$K[1]}) ) {
|
||||
$nm += 1; $summ += $w; $summ2 += $w*$w;
|
||||
}
|
||||
$sumf /= $nf; $sumf2 /= $nf;
|
||||
$summ /= $nm; $summ2 /= $nm;
|
||||
$sumf2 -= $sumf*$sumf;
|
||||
$summ2 -= $summ*$summ;
|
||||
$avgwithin = 0.5*($sumf2+$summ2 );
|
||||
$diff = abs($sumf - $summ) / sqrt($avgwithin);
|
||||
print "% class separation is $diff\n";
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# creates integer lists of silence and non-silence phones in files,
|
||||
# e.g. silphones.csl="1:2:3 \n"
|
||||
# and nonsilphones.csl="4:5:6:7:...:24\n";
|
||||
|
||||
if(@ARGV != 4) {
|
||||
die "Usage: silphones.pl phones.txt \"sil1 sil2 sil3\" silphones.csl nonsilphones.csl";
|
||||
}
|
||||
|
||||
($symtab, $sillist, $silphones, $nonsilphones) = @ARGV;
|
||||
open(S,"<$symtab") || die "Opening symbol table $symtab";
|
||||
|
||||
|
||||
foreach $s (split(" ", $sillist)) {
|
||||
$issil{$s} = 1;
|
||||
}
|
||||
|
||||
@sil = ();
|
||||
@nonsil = ();
|
||||
while(<S>){
|
||||
@A = split(" ", $_);
|
||||
@A == 2 || die "Bad line $_ in phone-symbol-table file $symtab";
|
||||
($sym, $int) = @A;
|
||||
if($int != 0) {
|
||||
if($issil{$sym}) { push @sil, $int; $seensil{$sym}=1; }
|
||||
else { push @nonsil, $int; }
|
||||
}
|
||||
}
|
||||
|
||||
foreach $k(keys %issil) {
|
||||
if(!$seensil{$k}) { die "No such silence phone $k"; }
|
||||
}
|
||||
open(F, ">$silphones") || die "opening silphones file $silphones";
|
||||
open(G, ">$nonsilphones") || die "opening nonsilphones file $nonsilphones";
|
||||
print F join(":", @sil) . "\n";
|
||||
print G join(":", @nonsil) . "\n";
|
||||
close(F);
|
||||
close(G);
|
||||
if(@sil == 0) { print STDERR "Warning: silphones.pl no silence phones.\n" }
|
||||
if(@nonsil == 0) { print STDERR "Warning: silphones.pl no non-silence phones.\n" }
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
while(<>){
|
||||
@A = split(" ", $_);
|
||||
@A > 1 || die "Invalid line in spk2utt file: $_";
|
||||
$s = shift @A;
|
||||
foreach $u ( @A ) {
|
||||
print "$u $s\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,181 @@
|
|||
#!/usr/bin/perl -w
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
# This program splits up any kind of .scp or archive-type file.
|
||||
# If there is no utt2spk option it will work on any text file and
|
||||
# will split it up with an approximately equal number of lines in
|
||||
# each but.
|
||||
# With the --utt2spk option it will work on anything that has the
|
||||
# utterance-id as the first entry on each line; the utt2spk file is
|
||||
# of the form "utterance speaker" (on each line).
|
||||
# It splits it into equal size chunks as far as it can. If you use
|
||||
# the utt2spk option it will make sure these chunks coincide with
|
||||
# speaker boundaries. In this case, if there are more chunks
|
||||
# than speakers (and in some other circumstances), some of the
|
||||
# resulting chunks will be empty and it
|
||||
# will print a warning.
|
||||
# You will normally call this like:
|
||||
# split_scp.pl scp scp.1 scp.2 scp.3 ...
|
||||
# or
|
||||
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
|
||||
# Note that you can use this script to split the utt2spk file itself,
|
||||
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
|
||||
|
||||
if(@ARGV < 2 ) {
|
||||
die "Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ... ";
|
||||
}
|
||||
|
||||
if($ARGV[0] =~ m:^-:) {
|
||||
# Everything inside this block
|
||||
# corresponds to what we do when the --utt2spk option is used.
|
||||
$opt = shift @ARGV;
|
||||
@A = split("=", $opt);
|
||||
if(@A != 2 || $A[0] ne "--utt2spk") {
|
||||
die "split_scp.pl: invalid option $ARGV[0]";
|
||||
}
|
||||
$utt2spk_file = $A[1];
|
||||
open(U, "<$utt2spk_file") || die "Failed to open utt2spk file $utt2spk_file";
|
||||
while(<U>) {
|
||||
@A = split;
|
||||
@A == 2 || die "Bad line $_ in utt2spk file $utt2spk_file";
|
||||
($u,$s) = @A;
|
||||
$utt2spk{$u} = $s;
|
||||
}
|
||||
$inscp = shift @ARGV;
|
||||
open(I, "<$inscp") || die "Opening input scp file $inscp";
|
||||
@spkrs = ();
|
||||
while(<I>) {
|
||||
@A = split;
|
||||
if(@A == 0) { die "Empty or space-only line in scp file $inscp"; }
|
||||
$u = $A[0];
|
||||
$s = $utt2spk{$u};
|
||||
if(!defined $s) { die "No such utterance $u in utt2spk file $utt2spk_file"; }
|
||||
if(!defined $spk_count{$s}) {
|
||||
push @spkrs, $s;
|
||||
$spk_count{$s} = 0;
|
||||
$spk_data{$s} = "";
|
||||
}
|
||||
$spk_count{$s}++;
|
||||
$spk_data{$s} = $spk_data{$s} . $_;
|
||||
}
|
||||
# Now split as equally as possible ..
|
||||
# First allocate spks to files by given approximately
|
||||
# equal #spks.
|
||||
$numspks = @spkrs; # number of speakers.
|
||||
$numscps = @ARGV; # number of output files.
|
||||
$spksperscp = int( ($numspks+($numscps-1)) / $numscps); # the +$(numscps-1) forces rounding up.
|
||||
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
|
||||
$scparray[$scpidx] = []; # [] is array reference.
|
||||
for($n = $spksperscp * $scpidx;
|
||||
$n < $numspks && $n < $spksperscp*($scpidx+1);
|
||||
$n++) {
|
||||
$spk = $spkrs[$n];
|
||||
push @{$scparray[$scpidx]}, $spk;
|
||||
$scpcount[$scpidx] += $spk_count{$spk};
|
||||
}
|
||||
}
|
||||
# Now will try to reassign beginning + ending speakers
|
||||
# to different scp's and see if it gets more balanced.
|
||||
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
|
||||
# We can show that if considering changing just 2 scp's, we minimize
|
||||
# this by minimizing the squared difference in sizes. This is
|
||||
# equivalent to minimizing the absolute difference in sizes. This
|
||||
# shows this method is bound to converge.
|
||||
|
||||
$changed = 1;
|
||||
while($changed) {
|
||||
$changed = 0;
|
||||
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
|
||||
# First try to reassign ending spk of this scp.
|
||||
if($scpidx < $numscps-1) {
|
||||
$sz = @{$scparray[$scpidx]};
|
||||
if($sz > 0) {
|
||||
$spk = $scparray[$scpidx]->[$sz-1];
|
||||
$count = $spk_count{$spk};
|
||||
$nutt1 = $scpcount[$scpidx];
|
||||
$nutt2 = $scpcount[$scpidx+1];
|
||||
if( abs( ($nutt2+$count) - ($nutt1-$count))
|
||||
< abs($nutt2 - $nutt1)) { # Would decrease
|
||||
# size-diff by reassigning spk...
|
||||
$scpcount[$scpidx+1] += $count;
|
||||
$scpcount[$scpidx] -= $count;
|
||||
pop @{$scparray[$scpidx]};
|
||||
unshift @{$scparray[$scpidx+1]}, $spk;
|
||||
$changed = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
|
||||
$spk = $scparray[$scpidx]->[0];
|
||||
$count = $spk_count{$spk};
|
||||
$nutt1 = $scpcount[$scpidx-1];
|
||||
$nutt2 = $scpcount[$scpidx];
|
||||
if( abs( ($nutt2-$count) - ($nutt1+$count))
|
||||
< abs($nutt2 - $nutt1)) { # Would decrease
|
||||
# size-diff by reassigning spk...
|
||||
$scpcount[$scpidx-1] += $count;
|
||||
$scpcount[$scpidx] -= $count;
|
||||
shift @{$scparray[$scpidx]};
|
||||
push @{$scparray[$scpidx-1]}, $spk;
|
||||
$changed = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
# Now print out the files...
|
||||
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
|
||||
$scpfn = $ARGV[$scpidx];
|
||||
open(F, ">$scpfn") || die "Could not open scp file $scpfn for writing.";
|
||||
$count = 0;
|
||||
if(@{$scparray[$scpidx]} == 0) {
|
||||
print STDERR "Warning: split_scp.pl producing empty .scp file $scpfn (too many splits and too few speakers?)";
|
||||
}
|
||||
foreach $spk ( @{$scparray[$scpidx]} ) {
|
||||
print F $spk_data{$spk};
|
||||
$count += $spk_count{$spk};
|
||||
}
|
||||
if($count != $scpcount[$scpidx]) { die "Count mismatch [code error]"; }
|
||||
close(F);
|
||||
}
|
||||
} else {
|
||||
# This block is the "normal" case where there is no --utt2spk
|
||||
# option and we just break into equal size chunks.
|
||||
|
||||
$inscp = shift @ARGV;
|
||||
open(I, "<$inscp") || die "Opening input scp file $inscp";
|
||||
|
||||
$numscps = @ARGV; # size of array.
|
||||
@F = ();
|
||||
while(<I>) {
|
||||
push @F, $_;
|
||||
}
|
||||
$numlines = @F;
|
||||
if($numlines == 0) {
|
||||
print STDERR "split_scp.pl: warning: empty input scp file $inscp";
|
||||
}
|
||||
$linesperscp = int( ($numlines+($numscps-1)) / $numscps); # the +$(numscps-1) forces rounding up.
|
||||
# [just doing int() rounds down].
|
||||
for($scpidx = 0; $scpidx < @ARGV; $scpidx++) {
|
||||
$scpfile = $ARGV[$scpidx];
|
||||
open(O, ">$scpfile") || die "Opening output scp file $scpfile";
|
||||
for($n = $linesperscp * $scpidx; $n < $numlines && $n < $linesperscp*($scpidx+1); $n++) {
|
||||
print O $F[$n];
|
||||
}
|
||||
close(O) || die "Closing scp file $scpfile";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
#!/usr/bin/perl -w
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# This program selects a subset of N elements in the scp.
|
||||
# It selects them evenly from throughout the scp, in order to
|
||||
# avoid selecting too many from the same speaker.
|
||||
# It prints them on the standard output.
|
||||
|
||||
if(@ARGV < 2 ) {
|
||||
die "Usage: subset_scp.pl N in.scp ";
|
||||
}
|
||||
|
||||
$N = shift @ARGV;
|
||||
if($N == 0) {
|
||||
die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\"";
|
||||
}
|
||||
$inscp = shift @ARGV;
|
||||
open(I, "<$inscp") || die "Opening input scp file $inscp";
|
||||
|
||||
@F = ();
|
||||
while(<I>) {
|
||||
push @F, $_;
|
||||
}
|
||||
$numlines = @F;
|
||||
if($N > $numlines) {
|
||||
die "You requested from subset_scp.pl more elements than available: $N > $numlines";
|
||||
}
|
||||
|
||||
sub select_n {
|
||||
my ($start,$end,$num_needed) = @_;
|
||||
my $diff = $end - $start;
|
||||
if($num_needed > $diff) { die "select_n: code error"; }
|
||||
if($diff == 1 ) {
|
||||
if($num_needed > 0) {
|
||||
print $F[$start];
|
||||
}
|
||||
} else {
|
||||
my $halfdiff = int($diff/2);
|
||||
my $halfneeded = int($num_needed/2);
|
||||
select_n($start, $start+$halfdiff, $halfneeded);
|
||||
select_n($start+$halfdiff, $end, $num_needed - $halfneeded);
|
||||
}
|
||||
}
|
||||
select_n(0, $numlines, $N);
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
$ignore_oov = 0;
|
||||
$ignore_first_field = 0;
|
||||
for($x = 0; $x < 2; $x++) {
|
||||
if($ARGV[0] eq "--ignore-oov") { $ignore_oov = 1; shift @ARGV; }
|
||||
if($ARGV[0] eq "--ignore-first-field") { $ignore_first_field = 1; shift @ARGV; }
|
||||
}
|
||||
|
||||
$symtab = shift @ARGV;
|
||||
if(!defined $symtab) {
|
||||
die "Usage: sym2int.pl symtab [input transcriptions] > output transcriptions\n";
|
||||
}
|
||||
open(F, "<$symtab") || die "Error opening symbol table file $symtab";
|
||||
while(<F>) {
|
||||
@A = split(" ", $_);
|
||||
@A == 2 || die "bad line in symbol table file: $_";
|
||||
$sym2int{$A[0]} = $A[1] + 0;
|
||||
}
|
||||
|
||||
while(<>) {
|
||||
@A = split(" ", $_);
|
||||
if(@A == 0) {
|
||||
die "Empty line in transcriptions input.";
|
||||
}
|
||||
if($ignore_first_field) {
|
||||
$key = shift @A;
|
||||
print $key . " ";
|
||||
}
|
||||
foreach $a (@A) {
|
||||
$i = $sym2int{$a};
|
||||
if(!defined ($i)) {
|
||||
if($ignore_oov) {
|
||||
print $a . " " ;
|
||||
} else {
|
||||
die "sym2int.pl: undefined symbol $a\n";
|
||||
}
|
||||
}
|
||||
print $i . " ";
|
||||
}
|
||||
print "\n";
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/perl
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
while(<>){
|
||||
@A = split(" ", $_);
|
||||
@A == 2 || die "Invalid line in utt2spk file: $_";
|
||||
($u,$s) = @A;
|
||||
if(!$seen_spk{$s}) {
|
||||
$seen_spk{$s} = 1;
|
||||
push @spklist, $s;
|
||||
}
|
||||
$uttlist{$s} = $uttlist{$s} . "$u ";
|
||||
}
|
||||
foreach $s (@spklist) {
|
||||
$l = $uttlist{$s};
|
||||
$l =~ s: $::; # remove trailing space.
|
||||
print "$s $l\n";
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Monophone decoding script.
|
||||
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
dir=exp/decode_mono
|
||||
tree=exp/mono/tree
|
||||
mkdir -p $dir
|
||||
model=exp/mono/final.mdl
|
||||
graphdir=exp/graph_mono
|
||||
|
||||
scripts/mkgraph.sh --mono $tree $model $graphdir
|
||||
|
||||
for test in mar87 oct87 feb89 oct89 feb91 sep92; do
|
||||
(
|
||||
feats="ark:add-deltas --print-args=false scp:data/test_${test}.scp ark:- |"
|
||||
|
||||
gmm-decode-faster --beam=20.0 --acoustic-scale=0.08333 --word-symbol-table=data/words.txt $model $graphdir/HCLG.fst "$feats" ark,t:$dir/test_${test}.tra ark,t:$dir/test_${test}.ali 2> $dir/decode_${test}.log
|
||||
|
||||
# the ,p option lets it score partial output without dying..
|
||||
scripts/sym2int.pl --ignore-first-field data/words.txt data_prep/test_${test}_trans.txt | \
|
||||
compute-wer --mode=present ark:- ark,p:$dir/test_${test}.tra >& $dir/wer_${test}
|
||||
) &
|
||||
done
|
||||
|
||||
wait
|
||||
|
||||
grep WER $dir/wer_* | \
|
||||
awk '{n=n+$4; d=d+$6} END{ printf("Average WER is %f (%d / %d) \n", (100.0*n)/d, n, d); }' \
|
||||
> $dir/wer
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# To be run from .. (one directory up from here)
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "usage: make_mfcc_test.sh <abs-path-to-tmpdir>"
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
|
||||
dir=exp/make_mfcc
|
||||
mkdir -p $dir
|
||||
root_out=$1
|
||||
mkdir -p $root_out
|
||||
|
||||
for test in mar87 oct87 feb89 oct89 feb91 sep92; do
|
||||
scpin=data_prep/test_${test}_wav.scp
|
||||
# Making it like this so it works for others on the BUT filesystem.
|
||||
# It will generate the correct scp file without running the feature extraction.
|
||||
log=$dir/make_mfcc_test_${test}.log
|
||||
(
|
||||
compute-mfcc-feats --verbose=2 --config=conf/mfcc.conf scp:$scpin ark,scp:$root_out/test_${test}_raw_mfcc.ark,$root_out/test_${test}_raw_mfcc.scp 2> $log || tail $log
|
||||
cp $root_out/test_${test}_raw_mfcc.scp data/test_${test}.scp
|
||||
) &
|
||||
done
|
||||
|
||||
wait
|
||||
|
||||
echo "If the above produced no output on the screen, it succeeded."
|
|
@ -0,0 +1,43 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# To be run from .. (one directory up from here)
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "usage: make_mfcc_train.sh <abs-path-to-tmpdir>";
|
||||
exit 1;
|
||||
fi
|
||||
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
|
||||
scpin=data_prep/train_wav.scp
|
||||
dir=exp/make_mfcc
|
||||
mkdir -p $dir
|
||||
root_out=$1
|
||||
mkdir -p $root_out
|
||||
|
||||
scripts/split_scp.pl $scpin $dir/train_wav{1,2,3,4}.scp
|
||||
|
||||
for n in 1 2 3 4; do # Use 4 CPUs
|
||||
log=$dir/make_mfcc_train.$n.log
|
||||
compute-mfcc-feats --verbose=2 --config=conf/mfcc.conf scp:$dir/train_wav${n}.scp ark,scp:$root_out/train_raw_mfcc${n}.ark,$root_out/train_raw_mfcc${n}.scp 2> $log || tail $log &
|
||||
done
|
||||
|
||||
wait;
|
||||
|
||||
cat $root_out/train_raw_mfcc{1,2,3,4}.scp > data/train.scp
|
||||
|
||||
echo "If the above produced no output on the screen, it succeeded."
|
|
@ -0,0 +1,66 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# The output of this script is the symbol tables data/{words.txt,phones.txt},
|
||||
# and the grammars and lexicons data/{L,G}{,_disambig}.fst
|
||||
|
||||
# To be run from ..
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
|
||||
cp data_prep/G.txt data/
|
||||
scripts/make_words_symtab.pl < data/G.txt > data/words.txt
|
||||
cp data_prep/lexicon.txt data/
|
||||
|
||||
|
||||
scripts/make_phones_symtab.pl < data/lexicon.txt > data/phones.txt
|
||||
|
||||
silphones="sil"; # This would in general be a space-separated list of all silence phones. E.g. "sil vn"
|
||||
# Generate colon-separated lists of silence and non-silence phones.
|
||||
scripts/silphones.pl data/phones.txt "$silphones" data/silphones.csl data/nonsilphones.csl
|
||||
|
||||
ndisambig=`scripts/add_lex_disambig.pl data/lexicon.txt data/lexicon_disambig.txt`
|
||||
scripts/add_disambig.pl data/phones.txt $ndisambig > data/phones_disambig.txt
|
||||
|
||||
|
||||
# Create train transcripts in integer format:
|
||||
cat data_prep/train_trans.txt | \
|
||||
scripts/sym2int.pl --ignore-first-field data/words.txt > data/train.tra
|
||||
|
||||
|
||||
# Get lexicon in FST format.
|
||||
|
||||
# silprob = 0.5: same prob as word.
|
||||
scripts/make_lexicon_fst.pl data/lexicon.txt 0.5 sil | fstcompile --isymbols=data/phones.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel > data/L.fst
|
||||
|
||||
scripts/make_lexicon_fst.pl data/lexicon_disambig.txt 0.5 sil | fstcompile --isymbols=data/phones_disambig.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel > data/L_disambig.fst
|
||||
|
||||
fstcompile --isymbols=data/words.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false data/G.txt > data/G.fst
|
||||
|
||||
# Checking that G is stochastic [note, it wouldn't be for an Arpa]
|
||||
fstisstochastic data/G.fst || echo Error
|
||||
|
||||
|
||||
# Checking that disambiguated lexicon times G is determinizable
|
||||
fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminize >/dev/null || echo Error
|
||||
|
||||
# Checking that LG is stochastic:
|
||||
fsttablecompose data/L.fst data/G.fst | fstisstochastic || echo Error
|
||||
|
||||
## Check lexicon.
|
||||
## just have a look and make sure it seems sane.
|
||||
fstprint --isymbols=data/phones.txt --osymbols=data/words.txt data/L.fst | head
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2010-2011 Microsoft Corporation Arnab Ghoshal
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
# MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
# See the Apache 2 License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# To be run from ..
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
|
||||
# Train the monophone on a subset-- no point using all the data.
|
||||
dir=exp/mono
|
||||
n=1000
|
||||
feats="ark:add-deltas --print-args=false scp:$dir/train.scp ark:- |"
|
||||
# need to quote when passing as an argument, as in "$feats",
|
||||
# since it has spaces in it.
|
||||
scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1"
|
||||
|
||||
numiters=30 # Number of iterations of training
|
||||
maxiterinc=20 # Last iter to increase #Gauss on.
|
||||
numgauss=250 # Initial num-Gauss (must be more than #states=3*phones).
|
||||
totgauss=1000 # Target #Gaussians.
|
||||
incgauss=$[($totgauss-$numgauss)/$maxiterinc] # per-iter increment for #Gauss
|
||||
realign_iters="1 2 3 4 5 6 7 8 9 10 12 15 20 25";
|
||||
|
||||
|
||||
mkdir -p $dir
|
||||
scripts/subset_scp.pl $n data/train.scp > $dir/train.scp
|
||||
|
||||
|
||||
silphones=`cat data/silphones.csl | sed 's/:/ /g'`
|
||||
nonsilphones=`cat data/nonsilphones.csl | sed 's/:/ /g'`
|
||||
cat conf/topo.proto | sed "s:NONSILENCEPHONES:$nonsilphones:" | sed "s:SILENCEPHONES:$silphones:" > $dir/topo
|
||||
|
||||
gmm-init-mono '--train-feats=ark:head -10 data/train.scp | add-deltas scp:- ark:- |' $dir/topo 39 $dir/0.mdl $dir/tree 2> $dir/init.out || exit 1;
|
||||
|
||||
|
||||
echo "Compiling training graphs"
|
||||
compile-train-graphs $dir/tree $dir/0.mdl data/L.fst \
|
||||
"ark:scripts/subset_scp.pl $n data/train.tra|" \
|
||||
"ark:|gzip -c >$dir/graphs.fsts.gz" 2>$dir/compile_graphs.log || exit 1
|
||||
|
||||
echo Pass 0
|
||||
|
||||
align-equal-compiled "ark:gunzip -c $dir/graphs.fsts.gz|" "$feats" \
|
||||
ark,t,f:- 2>$dir/align.0.log | \
|
||||
gmm-acc-stats-ali --binary=true $dir/0.mdl "$feats" ark:- \
|
||||
$dir/0.acc 2> $dir/acc.0.log || exit 1;
|
||||
|
||||
# In the following steps, the --min-gaussian-occupancy=3 option is important, otherwise
|
||||
# we fail to est "rare" phones and later on, they never align properly.
|
||||
gmm-est --min-gaussian-occupancy=3 --mix-up=$numgauss \
|
||||
$dir/0.mdl $dir/0.acc $dir/1.mdl 2> $dir/update.0.log || exit 1;
|
||||
|
||||
rm $dir/0.acc
|
||||
|
||||
|
||||
beam=4 # will change to 8 below after 1st pass
|
||||
x=1
|
||||
while [ $x -lt $numiters ]; do
|
||||
echo "Pass $x"
|
||||
if echo $realign_iters | grep -w $x >/dev/null; then
|
||||
echo "Aligning data"
|
||||
gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$[$beam*4] $dir/$x.mdl \
|
||||
"ark:gunzip -c $dir/graphs.fsts.gz|" "$feats" t,ark:$dir/cur.ali \
|
||||
2> $dir/align.$x.log || exit 1;
|
||||
fi
|
||||
gmm-acc-stats-ali --binary=false $dir/$x.mdl "$feats" ark:$dir/cur.ali $dir/$x.acc 2> $dir/acc.$x.log || exit 1;
|
||||
gmm-est --mix-up=$numgauss $dir/$x.mdl $dir/$x.acc $dir/$[$x+1].mdl 2> $dir/update.$x.log || exit 1;
|
||||
rm $dir/$x.mdl $dir/$x.acc
|
||||
if [ $x -le $maxiterinc ]; then
|
||||
numgauss=$[$numgauss+$incgauss];
|
||||
fi
|
||||
beam=8
|
||||
x=$[$x+1]
|
||||
done
|
||||
|
||||
( cd $dir; rm final.mdl 2>/dev/null; ln -s $x.mdl final.mdl )
|
||||
|
||||
# example of showing the alignments:
|
||||
# show-alignments data/phones.txt $dir/30.mdl ark:$dir/cur.ali | head -4
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
#!/bin/bash
|
||||
|
||||
# To be run from ..
|
||||
if [ -f path.sh ]; then . path.sh; fi
|
||||
|
||||
dir=exp/nnet
|
||||
mkdir -p $dir/{log,nnet}
|
||||
|
||||
#use following features and alignments
|
||||
cp exp/mono/train.scp exp/mono/cur.ali $dir
|
||||
head -n 800 $dir/train.scp > $dir/train.scp.tr
|
||||
tail -n 200 $dir/train.scp > $dir/train.scp.cv
|
||||
feats="ark:add-deltas --print-args=false scp:$dir/train.scp ark:- |"
|
||||
feats_tr="ark:add-deltas --print-args=false scp:$dir/train.scp.tr ark:- |"
|
||||
feats_cv="ark:add-deltas --print-args=false scp:$dir/train.scp.cv ark:- |"
|
||||
labels="ark:$dir/cur.ali"
|
||||
|
||||
#compute per utterance CMVN
|
||||
cmvn="ark:$dir/cmvn.ark"
|
||||
compute-cmvn-stats "$feats" $cmvn
|
||||
feats_tr="$feats_tr apply-cmvn --print-args=false --norm-vars=true $cmvn ark:- ark:- |"
|
||||
feats_cv="$feats_cv apply-cmvn --print-args=false --norm-vars=true $cmvn ark:- ark:- |"
|
||||
|
||||
|
||||
#initialize the nnet
|
||||
mlp_init=$dir/nnet.init
|
||||
scripts/gen_mlp_init.py --dim=39:512:300 --gauss --negbias > $mlp_init
|
||||
|
||||
#global config for trainig
|
||||
max_iters=20
|
||||
start_halving_inc=0.5
|
||||
end_halving_inc=0.1
|
||||
lrate=0.001
|
||||
|
||||
|
||||
|
||||
nnet-train-xent-hardlab-perutt --cross-validate=true $mlp_init "$feats_cv" "$labels" &> $dir/log/prerun.log
|
||||
if [ $? != 0 ]; then cat $dir/log/prerun.log; exit 1; fi
|
||||
acc=$(cat $dir/log/prerun.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1)
|
||||
echo CROSSVAL PRERUN ACCURACY $acc
|
||||
|
||||
mlp_best=$mlp_init
|
||||
mlp_base=${mlp_init##*/}; mlp_base=${mlp_base%.*}
|
||||
halving=0
|
||||
for iter in $(seq -w $max_iters); do
|
||||
mlp_next=$dir/nnet/${mlp_base}_iter${iter}
|
||||
nnet-train-xent-hardlab-perutt --learn-rate=$lrate $mlp_best "$feats_tr" "$labels" $mlp_next &> $dir/log/iter$iter.log
|
||||
if [ $? != 0 ]; then cat $dir/log/iter$iter.log; exit 1; fi
|
||||
tr_acc=$(cat $dir/log/iter$iter.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1)
|
||||
echo TRAIN ITERATION $iter ACCURACY $tr_acc LRATE $lrate
|
||||
nnet-train-xent-hardlab-perutt --cross-validate=true $mlp_next "$feats_cv" "$labels" 1>>$dir/log/iter$iter.log 2>>$dir/log/iter$iter.log
|
||||
if [ $? != 0 ]; then cat $dir/log/iter$iter.log; exit 1; fi
|
||||
|
||||
#accept or reject new parameters
|
||||
acc_new=$(cat $dir/log/iter$iter.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1)
|
||||
echo CROSSVAL ITERATION $iter ACCURACY $acc_new
|
||||
acc_prev=$acc
|
||||
if [ 1 == $(awk 'BEGIN{print('$acc_new' > '$acc')}') ]; then
|
||||
acc=$acc_new
|
||||
mlp_best=$dir/nnet/$mlp_base.iter${iter}_tr$(printf "%.5g" $tr_acc)_cv$(printf "%.5g" $acc_new)
|
||||
mv $mlp_next $mlp_best
|
||||
echo nnet $mlp_best accepted
|
||||
else
|
||||
mlp_reject=$dir/nnet/$mlp_base.iter${iter}_tr$(printf "%.5g" $tr_acc)_cv$(printf "%.5g" $acc_new)
|
||||
mv $mlp_next $mlp_reject
|
||||
echo nnet $mlp_reject rejected
|
||||
fi
|
||||
|
||||
#stopping criterion
|
||||
if [[ 1 == $halving && 1 == $(awk 'BEGIN{print('$acc' < '$acc_prev'+'$end_halving_inc')}') ]]; then
|
||||
echo finished, too small improvement $(awk 'BEGIN{print('$acc'-'$acc_prev')}')
|
||||
break
|
||||
fi
|
||||
|
||||
#start annealing when improvement is low
|
||||
if [ 1 == $(awk 'BEGIN{print('$acc' < '$acc_prev'+'$start_halving_inc')}') ]; then
|
||||
halving=1
|
||||
fi
|
||||
|
||||
#do annealing
|
||||
if [ 1 == $halving ]; then
|
||||
lrate=$(awk 'BEGIN{print('$lrate'*0.5)}')
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $mlp_best != $mlp_init ]; then
|
||||
iter=$(echo $mlp_best | sed 's/^.*iter\([0-9][0-9]*\).*$/\1/')
|
||||
fi
|
||||
mlp_final=$dir/${mlp_base}_final_iter${iter:-0}_acc${acc}
|
||||
cp $mlp_best $mlp_final
|
||||
echo final network $mlp_final
|
||||
|
|
@ -7,7 +7,8 @@
|
|||
|
||||
SUBDIRS = base matrix util feat tree optimization gmm transform sgmm \
|
||||
fstext hmm lm decoder \
|
||||
bin fstbin gmmbin fgmmbin sgmmbin featbin
|
||||
bin fstbin gmmbin fgmmbin sgmmbin featbin \
|
||||
nnet nnetbin
|
||||
|
||||
all: $(SUBDIRS)
|
||||
echo Done
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
|
||||
|
||||
all:
|
||||
|
||||
include ../kaldi.mk
|
||||
|
||||
TESTFILES = nnet-test
|
||||
|
||||
OBJFILES = nnet-nnet.o nnet-component.o nnet-loss.o
|
||||
|
||||
LIBFILE = kaldi-nnet.a
|
||||
|
||||
all: $(LIBFILE) $(TESTFILES)
|
||||
|
||||
$(LIBFILE): $(OBJFILES)
|
||||
$(AR) -cru $(LIBFILE) $(OBJFILES)
|
||||
$(RANLIB) $(LIBFILE)
|
||||
|
||||
|
||||
|
||||
$(TESTFILES): $(LIBFILE) ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../util/kaldi-util.a
|
||||
|
||||
# Rule below would expand to, e.g.:
|
||||
# ../base/kaldi-base.a:
|
||||
# make -c ../base kaldi-base.a
|
||||
# -c option to make is same as changing directory.
|
||||
%.a:
|
||||
$(MAKE) -C ${@D} ${@F}
|
||||
|
||||
clean:
|
||||
rm *.o *.a $(TESTFILES)
|
||||
|
||||
test: $(TESTFILES)
|
||||
for x in $(TESTFILES); do ./$$x >&/dev/null || (echo "***test $$x failed***"; exit 1); done
|
||||
echo Tests succeeded
|
||||
|
||||
.valgrind: $(TESTFILES)
|
||||
|
||||
|
||||
depend:
|
||||
-$(CXX) -M $(CXXFLAGS) *.cc > .depend.mk
|
||||
|
||||
# removing automatic making of "depend" as it's quite slow.
|
||||
#.depend.mk: depend
|
||||
|
||||
-include .depend.mk
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
// nnet/nnet-activation.h
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef KALDI_NNET_ACTIVATION_H
|
||||
#define KALDI_NNET_ACTIVATION_H
|
||||
|
||||
#include "nnet/nnet-component.h"
|
||||
namespace kaldi {
|
||||
|
||||
class Sigmoid : public Component {
|
||||
public:
|
||||
Sigmoid(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet)
|
||||
: Component(dim_in, dim_out, nnet)
|
||||
{ }
|
||||
~Sigmoid()
|
||||
{ }
|
||||
|
||||
ComponentType GetType() const {
|
||||
return kSigmoid;
|
||||
}
|
||||
|
||||
void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
|
||||
//y = 1/(1+e^-x)
|
||||
for(MatrixIndexT r=0; r<out->NumRows(); r++) {
|
||||
for(MatrixIndexT c=0; c<out->NumCols(); c++) {
|
||||
(*out)(r,c) = 1.0/(1.0+exp(-in(r,c)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
|
||||
//ey = y(1-y)ex
|
||||
const Matrix<BaseFloat>& y = nnet_->PropagateBuffer()[nnet_->IndexOfLayer(*this)+1];
|
||||
|
||||
for(MatrixIndexT r=0; r<out_err->NumRows(); r++) {
|
||||
for(MatrixIndexT c=0; c<out_err->NumCols(); c++) {
|
||||
(*out_err)(r,c) = y(r,c)*(1.0-y(r,c))*in_err(r,c);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class Softmax : public Component {
|
||||
public:
|
||||
Softmax(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet)
|
||||
: Component(dim_in, dim_out, nnet)
|
||||
{ }
|
||||
~Softmax()
|
||||
{ }
|
||||
|
||||
ComponentType GetType() const {
|
||||
return kSoftmax;
|
||||
}
|
||||
|
||||
void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
|
||||
//y = e^x_j/sum_j(e^x_j)
|
||||
out->CopyFromMat(in);
|
||||
for(MatrixIndexT r=0; r<out->NumRows(); r++) {
|
||||
out->Row(r).ApplySoftMax();
|
||||
}
|
||||
}
|
||||
|
||||
void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
|
||||
//simply copy the error
|
||||
//(ie. assume crossentropy error function,
|
||||
// while in_err contains (net_output-target) :
|
||||
// this is already derivative of the error with
|
||||
// respect to activations of last layer neurons)
|
||||
out_err->CopyFromMat(in_err);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
// nnet/nnet-biasedlinearity.h
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#ifndef KALDI_NNET_BIASEDLINEARITY_H
|
||||
#define KALDI_NNET_BIASEDLINEARITY_H
|
||||
|
||||
|
||||
#include "nnet/nnet-component.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
class BiasedLinearity : public UpdatableComponent {
|
||||
public:
|
||||
BiasedLinearity(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet)
|
||||
: UpdatableComponent(dim_in, dim_out, nnet),
|
||||
linearity_(dim_out,dim_in), bias_(dim_out),
|
||||
linearity_corr_(dim_out,dim_in), bias_corr_(dim_out)
|
||||
{ }
|
||||
~BiasedLinearity()
|
||||
{ }
|
||||
|
||||
ComponentType GetType() const {
|
||||
return kBiasedLinearity;
|
||||
}
|
||||
|
||||
void ReadData(std::istream& is, bool binary) {
|
||||
linearity_.Read(is,binary);
|
||||
bias_.Read(is,binary);
|
||||
|
||||
KALDI_ASSERT(linearity_.NumRows() == output_dim_);
|
||||
KALDI_ASSERT(linearity_.NumCols() == input_dim_);
|
||||
KALDI_ASSERT(bias_.Dim() == output_dim_);
|
||||
}
|
||||
|
||||
void WriteData(std::ostream& os, bool binary) const {
|
||||
linearity_.Write(os,binary);
|
||||
bias_.Write(os,binary);
|
||||
}
|
||||
|
||||
void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
|
||||
//precopy bias
|
||||
for (MatrixIndexT i=0; i<out->NumRows(); i++) {
|
||||
out->CopyRowFromVec(bias_,i);
|
||||
}
|
||||
//multiply by weights^t
|
||||
out->AddMatMat(1.0,in,kNoTrans,linearity_,kTrans,1.0);
|
||||
}
|
||||
|
||||
void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
|
||||
//multiply error by weights
|
||||
out_err->AddMatMat(1.0,in_err,kNoTrans,linearity_,kNoTrans,0.0);
|
||||
}
|
||||
|
||||
|
||||
void Update(const Matrix<BaseFloat>& input, const Matrix<BaseFloat>& err) {
|
||||
|
||||
//compute gradient
|
||||
linearity_corr_.AddMatMat(1.0,err,kTrans,input,kNoTrans,momentum_);
|
||||
bias_corr_.Scale(momentum_);
|
||||
bias_corr_.AddRowSumMat(err);
|
||||
//l2 regularization
|
||||
if(l2_penalty_ != 0.0) {
|
||||
linearity_.AddMat(-learn_rate_*l2_penalty_*input.NumRows(),linearity_);
|
||||
}
|
||||
//l1 regularization
|
||||
if(l1_penalty_ != 0.0) {
|
||||
BaseFloat l1 = learn_rate_*input.NumRows()*l1_penalty_;
|
||||
for(MatrixIndexT r=0; r<linearity_.NumRows(); r++) {
|
||||
for(MatrixIndexT c=0; c<linearity_.NumCols(); c++) {
|
||||
if(linearity_(r,c)==0.0) continue; //skip L1 if zero weight!
|
||||
BaseFloat l1sign = l1;
|
||||
if(linearity_(r,c) < 0.0)
|
||||
l1sign = -l1;
|
||||
BaseFloat before = linearity_(r,c);
|
||||
BaseFloat after = linearity_(r,c)-learn_rate_*linearity_corr_(r,c)-l1sign;
|
||||
if((after > 0.0) ^ (before > 0.0)) {
|
||||
linearity_(r,c) = 0.0;
|
||||
linearity_corr_(r,c) = 0.0;
|
||||
} else {
|
||||
linearity_(r,c) -= l1sign;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//update
|
||||
linearity_.AddMat(-learn_rate_,linearity_corr_);
|
||||
bias_.AddVec(-learn_rate_,bias_corr_);
|
||||
|
||||
/*
|
||||
std::cout <<"I"<< input.Row(0);
|
||||
std::cout <<"E"<< err.Row(0);
|
||||
std::cout <<"CORL"<< linearity_corr_.Row(0);
|
||||
std::cout <<"CORB"<< bias_corr_;
|
||||
std::cout <<"L"<< linearity_.Row(0);
|
||||
std::cout <<"B"<< bias_;
|
||||
std::cout << "\n";
|
||||
*/
|
||||
|
||||
//std::cout << l1_penalty_ << l2_penalty_ << momentum_ << learn_rate_ << "\n";
|
||||
}
|
||||
|
||||
private:
|
||||
Matrix<BaseFloat> linearity_;
|
||||
Vector<BaseFloat> bias_;
|
||||
|
||||
Matrix<BaseFloat> linearity_corr_;
|
||||
Vector<BaseFloat> bias_corr_;
|
||||
};
|
||||
|
||||
} //namespace
|
||||
|
||||
#endif
|
|
@ -0,0 +1,98 @@
|
|||
// nnet/nnet-component.h
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "nnet/nnet-component.h"
|
||||
|
||||
#include "nnet/nnet-nnet.h"
|
||||
#include "nnet/nnet-activation.h"
|
||||
#include "nnet/nnet-biasedlinearity.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
|
||||
const struct Component::key_value Component::kMarkerMap[] = {
|
||||
{ Component::kBiasedLinearity,"<biasedlinearity>" },
|
||||
{ Component::kSigmoid,"<sigmoid>" },
|
||||
{ Component::kSoftmax,"<softmax>" }
|
||||
};
|
||||
|
||||
|
||||
const char* Component::TypeToMarker(ComponentType t) {
|
||||
int32 N=sizeof(kMarkerMap)/sizeof(kMarkerMap[0]);
|
||||
for(int i=0; i<N; i++) {
|
||||
if(kMarkerMap[i].key == t)
|
||||
return kMarkerMap[i].value;
|
||||
}
|
||||
KALDI_ERR << "Unknown type" << t;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Component::ComponentType Component::MarkerToType(const std::string& s) {
|
||||
int32 N=sizeof(kMarkerMap)/sizeof(kMarkerMap[0]);
|
||||
for(int i=0; i<N; i++) {
|
||||
if(0 == strcmp(kMarkerMap[i].value,s.c_str()))
|
||||
return kMarkerMap[i].key;
|
||||
}
|
||||
KALDI_ERR << "Unknown marker" << s;
|
||||
return kUnknown;
|
||||
}
|
||||
|
||||
|
||||
Component* Component::Read(std::istream& is, bool binary, Nnet* nnet) {
|
||||
int32 dim_out, dim_in;
|
||||
std::string token;
|
||||
|
||||
int first_char = PeekMarker(is,binary);
|
||||
if(first_char == EOF) return NULL;
|
||||
|
||||
ReadMarker(is,binary,&token);
|
||||
Component::ComponentType comp_type = Component::MarkerToType(token);
|
||||
|
||||
ReadBasicType(is,binary,&dim_out);
|
||||
ReadBasicType(is,binary,&dim_in);
|
||||
|
||||
Component* p_comp;
|
||||
switch(comp_type) {
|
||||
case Component::kBiasedLinearity :
|
||||
p_comp = new BiasedLinearity(dim_in,dim_out,nnet);
|
||||
break;
|
||||
case Component::kSigmoid :
|
||||
p_comp = new Sigmoid(dim_in,dim_out,nnet);
|
||||
break;
|
||||
case Component::kSoftmax :
|
||||
p_comp = new Softmax(dim_in,dim_out,nnet);
|
||||
break;
|
||||
case Component::kUnknown :
|
||||
default :
|
||||
KALDI_ERR << "Missing type: " << token;
|
||||
}
|
||||
|
||||
p_comp->ReadData(is,binary);
|
||||
return p_comp;
|
||||
}
|
||||
|
||||
|
||||
void Component::Write(std::ostream& os, bool binary) const {
|
||||
WriteMarker(os,binary,Component::TypeToMarker(GetType()));
|
||||
WriteBasicType(os,binary,OutputDim());
|
||||
WriteBasicType(os,binary,InputDim());
|
||||
if(!binary) os << "\n";
|
||||
this->WriteData(os,binary);
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
|
@ -0,0 +1,273 @@
|
|||
// nnet/nnet-component.h
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
|
||||
#ifndef KALDI_NNET_COMPONENT_H
|
||||
#define KALDI_NNET_COMPONENT_H
|
||||
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "matrix/matrix-lib.h"
|
||||
//#include "nnet/nnet-nnet.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
//declare the nnet class so we can declare pointer
|
||||
class Nnet;
|
||||
|
||||
|
||||
/**
|
||||
* Abstract class, basic element of the network,
|
||||
* it is a box with defined inputs, outputs,
|
||||
* and tranformation functions interface.
|
||||
*
|
||||
* It is able to propagate and backpropagate
|
||||
* exact implementation is to be implemented in descendants.
|
||||
*
|
||||
* The data buffers are not included
|
||||
* and will be managed from outside.
|
||||
*/
|
||||
class Component
|
||||
{
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Disable copy construction and assignment
|
||||
private:
|
||||
Component(Component&);
|
||||
Component& operator=(Component&);
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Polymorphic Component RTTI
|
||||
public:
|
||||
/// Types of the net components
|
||||
typedef enum {
|
||||
kUnknown = 0x0,
|
||||
|
||||
kUpdatableComponent = 0x0100,
|
||||
kBiasedLinearity,
|
||||
kSharedLinearity,
|
||||
|
||||
kActivationFunction = 0x0200,
|
||||
kSoftmax,
|
||||
kSigmoid,
|
||||
|
||||
kTranform = 0x0400,
|
||||
kExpand,
|
||||
kCopy,
|
||||
kTranspose,
|
||||
kBlockLinearity,
|
||||
kBias,
|
||||
kWindow,
|
||||
kLog
|
||||
} ComponentType;
|
||||
/// Pair of type and marker
|
||||
struct key_value {
|
||||
const Component::ComponentType key;
|
||||
const char* value;
|
||||
};
|
||||
/// Mapping of types and markers
|
||||
static const struct key_value kMarkerMap[];
|
||||
/// Convert component type to marker
|
||||
static const char* TypeToMarker(ComponentType t);
|
||||
/// Convert marker to component type
|
||||
static ComponentType MarkerToType(const std::string& s);
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Constructor & Destructor
|
||||
public:
|
||||
Component(MatrixIndexT input_dim, MatrixIndexT output_dim, Nnet* nnet)
|
||||
: input_dim_(input_dim), output_dim_(output_dim), nnet_(nnet)
|
||||
{ }
|
||||
virtual ~Component()
|
||||
{ }
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Public interface
|
||||
public:
|
||||
/// Get Type Identification of the component
|
||||
virtual ComponentType GetType() const = 0;
|
||||
/// Check if contains trainable parameters
|
||||
virtual bool IsUpdatable() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Get size of input vectors
|
||||
MatrixIndexT InputDim() const {
|
||||
return input_dim_;
|
||||
}
|
||||
/// Get size of output vectors
|
||||
MatrixIndexT OutputDim() const {
|
||||
return output_dim_;
|
||||
}
|
||||
|
||||
/// Perform forward pass propagateion Input->Output
|
||||
void Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out);
|
||||
/// Perform backward pass propagateion ErrorInput->ErrorOutput
|
||||
void Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err);
|
||||
|
||||
/// Read component from stream
|
||||
static Component* Read(std::istream& is, bool binary, Nnet* nnet);
|
||||
/// Write component to stream
|
||||
void Write(std::ostream& os, bool binary) const;
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////
|
||||
// abstract interface for propagation/backpropagation
|
||||
protected:
|
||||
/// Forward pass transformation (to be implemented by descendents...)
|
||||
virtual void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) = 0;
|
||||
/// Backward pass transformation (to be implemented by descendents...)
|
||||
virtual void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) = 0;
|
||||
|
||||
/// Reads the component content
|
||||
virtual void ReadData(std::istream& is, bool binary)
|
||||
{ }
|
||||
/// Writes the component content
|
||||
virtual void WriteData(std::ostream& os, bool binary) const
|
||||
{ }
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////
|
||||
// data members
|
||||
protected:
|
||||
MatrixIndexT input_dim_; ///< Size of input vectors
|
||||
MatrixIndexT output_dim_; ///< Size of output vectors
|
||||
|
||||
Nnet* nnet_; ///< Pointer to the whole network
|
||||
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Class UpdatableComponent is a Component which has
|
||||
* trainable parameters and contains some global
|
||||
* parameters for stochastic gradient descent
|
||||
* (learnrate,momenutm,L2,L1)
|
||||
*/
|
||||
class UpdatableComponent : public Component {
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Constructor & Destructor
|
||||
public:
|
||||
UpdatableComponent(MatrixIndexT input_dim, MatrixIndexT output_dim, Nnet* nnet)
|
||||
: Component(input_dim,output_dim,nnet),
|
||||
learn_rate_(0.0), momentum_(0.0), l2_penalty_(0.0), l1_penalty_(0.0)
|
||||
{ }
|
||||
virtual ~UpdatableComponent()
|
||||
{ }
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Public interface
|
||||
public:
|
||||
/// Check if contains trainable parameters
|
||||
bool IsUpdatable() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Compute gradient and update parameters
|
||||
virtual void Update(const Matrix<BaseFloat>& input, const Matrix<BaseFloat>& err) = 0;
|
||||
|
||||
/// Sets the learning rate of gradient descent
|
||||
void LearnRate(BaseFloat lrate) {
|
||||
learn_rate_ = lrate;
|
||||
}
|
||||
/// Gets the learning rate of gradient descent
|
||||
BaseFloat LearnRate() {
|
||||
return learn_rate_;
|
||||
}
|
||||
|
||||
/// Sets momentum
|
||||
void Momentum(BaseFloat mmt) {
|
||||
momentum_ = mmt;
|
||||
}
|
||||
/// Gets momentum
|
||||
BaseFloat Momentum() {
|
||||
return momentum_;
|
||||
}
|
||||
|
||||
/// Sets L2 penalty (weight decay)
|
||||
void L2Penalty(BaseFloat l2) {
|
||||
l2_penalty_ = l2;
|
||||
}
|
||||
/// Gets L2 penalty (weight decay)
|
||||
BaseFloat L2Penalty() {
|
||||
return l2_penalty_;
|
||||
}
|
||||
|
||||
/// Sets L1 penalty (sparisity promotion)
|
||||
void L1Penalty(BaseFloat l1) {
|
||||
l1_penalty_ = l1;
|
||||
}
|
||||
/// Gets L1 penalty (sparisity promotion)
|
||||
BaseFloat L1Penalty() {
|
||||
return l1_penalty_;
|
||||
}
|
||||
|
||||
protected:
|
||||
BaseFloat learn_rate_; ///< learning rate (0.0..0.01)
|
||||
BaseFloat momentum_; ///< momentum value (0.0..1.0)
|
||||
BaseFloat l2_penalty_; ///< L2 regularization constant (0.0..1e-4)
|
||||
BaseFloat l1_penalty_; ///< L1 regularization constant (0.0..1e-4)
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// INLINE FUNCTIONS
|
||||
// Component::
|
||||
inline void Component::Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
|
||||
if(input_dim_ != in.NumCols()) {
|
||||
KALDI_ERR << "Nonmatching dims, component:" << input_dim_ << " data:" << in.NumCols();
|
||||
}
|
||||
|
||||
if(output_dim_ != out->NumCols() || in.NumRows() != out->NumRows()) {
|
||||
out->Resize(in.NumRows(), output_dim_);
|
||||
}
|
||||
|
||||
PropagateFnc(in, out);
|
||||
}
|
||||
|
||||
|
||||
inline void Component::Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
|
||||
if(output_dim_ != in_err.NumCols()) {
|
||||
KALDI_ERR << "Nonmatching dims, component:" << output_dim_
|
||||
<< " data:" << in_err.NumCols();
|
||||
}
|
||||
|
||||
if(input_dim_ != out_err->NumCols() || in_err.NumRows() != out_err->NumRows()) {
|
||||
out_err->Resize(in_err.NumRows(), input_dim_);
|
||||
}
|
||||
|
||||
BackpropagateFnc(in_err, out_err);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// INLINE FUNCTIONS
|
||||
// UpdatableComponent::
|
||||
|
||||
// nothing for now!
|
||||
|
||||
|
||||
|
||||
} // namespace kaldi
|
||||
|
||||
|
||||
#endif
|
|
@ -0,0 +1,151 @@
|
|||
// nnet/nnet-loss.cc
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "nnet/nnet-loss.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
|
||||
void Xent::Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff) {
|
||||
KALDI_ASSERT(net_out.NumCols() == target.NumCols());
|
||||
KALDI_ASSERT(net_out.NumRows() == target.NumRows());
|
||||
diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined);
|
||||
|
||||
//compute derivative wrt. activations of last layer of neurons
|
||||
diff->CopyFromMat(net_out);
|
||||
diff->AddMat(-1.0,target);
|
||||
|
||||
//we'll not produce per-frame classification accuracy for soft labels
|
||||
correct_ = -1;
|
||||
|
||||
//compute xentropy
|
||||
BaseFloat val;
|
||||
for(int32 r=0; r<net_out.NumRows(); r++) {
|
||||
for(int32 c=0; c<net_out.NumCols(); c++) {
|
||||
val = -target(r,c)*log(net_out(r,c));
|
||||
if(isinf(val)) val = 1e10;
|
||||
loss_ += val;
|
||||
}
|
||||
}
|
||||
|
||||
frames_ += net_out.NumRows();
|
||||
}
|
||||
|
||||
|
||||
void Xent::Eval(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target, Matrix<BaseFloat>* diff) {
|
||||
KALDI_ASSERT(net_out.NumRows() == (int32)target.size());
|
||||
|
||||
//check the labels
|
||||
int32 max=0;
|
||||
std::vector<int32>::const_iterator it;
|
||||
for(it=target.begin(); it!=target.end(); ++it) {
|
||||
if(max < *it) max = *it;
|
||||
}
|
||||
if(max > net_out.NumCols()) {
|
||||
KALDI_ERR << "Network has " << net_out.NumCols()
|
||||
<< " outputs while having " << max << " labels";
|
||||
}
|
||||
|
||||
//compute derivative wrt. activations of last layer of neurons
|
||||
diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined);
|
||||
diff->CopyFromMat(net_out);
|
||||
for(int32 r=0; r<(int32)target.size(); r++) {
|
||||
KALDI_ASSERT(target.at(r) <= diff->NumCols());
|
||||
(*diff)(r,target.at(r)-1) -= 1.0;
|
||||
}
|
||||
|
||||
//we'll not produce per-frame classification accuracy for soft labels
|
||||
correct_ += Correct(net_out,target);
|
||||
|
||||
//compute xentropy
|
||||
BaseFloat val;
|
||||
for(int32 r=0; r<net_out.NumRows(); r++) {
|
||||
KALDI_ASSERT(target.at(r) <= net_out.NumCols());
|
||||
val = -log(net_out(r,target.at(r)-1));
|
||||
if(isinf(val)) val = 1e10;
|
||||
loss_ += val;
|
||||
}
|
||||
|
||||
frames_ += net_out.NumRows();
|
||||
}
|
||||
|
||||
|
||||
std::string Xent::Report() {
|
||||
std::ostringstream oss;
|
||||
oss << "Xent:" << loss_ << " frames:" << frames_
|
||||
<< " err/frm:" << loss_/frames_;
|
||||
if(correct_ >= 0.0) {
|
||||
oss << " correct[" << 100.0*correct_/frames_ << "%]";
|
||||
}
|
||||
oss << std::endl;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
int32 Xent::Correct(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target) {
|
||||
int32 correct = 0;
|
||||
for(int32 r=0; r<net_out.NumRows(); r++) {
|
||||
BaseFloat max = -1;
|
||||
int32 max_id = -1;
|
||||
for(int32 c=0; c<net_out.NumCols(); c++) {
|
||||
if(max < net_out(r,c)) {
|
||||
max = net_out(r,c);
|
||||
max_id = c;
|
||||
}
|
||||
}
|
||||
if(target.at(r)-1 == max_id) {
|
||||
correct++;
|
||||
}
|
||||
}
|
||||
return correct;
|
||||
}
|
||||
|
||||
|
||||
void Mse::Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff) {
|
||||
KALDI_ASSERT(net_out.NumCols() == target.NumCols());
|
||||
KALDI_ASSERT(net_out.NumRows() == target.NumRows());
|
||||
diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined);
|
||||
|
||||
//compute derivative w.r.t. neural nerwork outputs
|
||||
diff->CopyFromMat(net_out);
|
||||
diff->AddMat(-1.0,target);
|
||||
|
||||
//compute mean square error
|
||||
BaseFloat val;
|
||||
for(int32 r=0; r<net_out.NumRows(); r++) {
|
||||
for(int32 c=0; c<net_out.NumCols(); c++) {
|
||||
val = target(r,c) - net_out(r,c);
|
||||
loss_ += val*val;
|
||||
}
|
||||
}
|
||||
|
||||
frames_ += net_out.NumRows();
|
||||
}
|
||||
|
||||
|
||||
std::string Mse::Report() {
|
||||
std::ostringstream oss;
|
||||
oss << "Mse:" << loss_ << " frames:" << frames_
|
||||
<< " err/frm:" << loss_/frames_
|
||||
<< std::endl;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
|
@ -0,0 +1,75 @@
|
|||
// nnet/nnet-loss.h
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef KALDI_NNET_LOSS_H
|
||||
#define KALDI_NNET_LOSS_H
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "matrix/matrix-lib.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
class Xent {
|
||||
public:
|
||||
Xent()
|
||||
: frames_(0), correct_(0), loss_(0.0)
|
||||
{ }
|
||||
~Xent()
|
||||
{ }
|
||||
|
||||
/// Evaluate cross entropy from hard labels
|
||||
void Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff);
|
||||
/// Evaluate cross entropy from soft labels
|
||||
void Eval(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target, Matrix<BaseFloat>* diff);
|
||||
|
||||
/// Generate string with error report
|
||||
std::string Report();
|
||||
|
||||
private:
|
||||
int32 Correct(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target);
|
||||
|
||||
private:
|
||||
int32 frames_;
|
||||
int32 correct_;
|
||||
double loss_;
|
||||
};
|
||||
|
||||
class Mse {
|
||||
public:
|
||||
Mse()
|
||||
: frames_(0), loss_(0.0)
|
||||
{ }
|
||||
~Mse()
|
||||
{ }
|
||||
|
||||
/// Evaluate mean square error from target values
|
||||
void Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff);
|
||||
|
||||
/// Generate string with error report
|
||||
std::string Report();
|
||||
|
||||
private:
|
||||
int32 frames_;
|
||||
double loss_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,217 @@
|
|||
// nnet/nnet-nnet.cc
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "nnet/nnet-nnet.h"
|
||||
#include "nnet/nnet-component.h"
|
||||
#include "nnet/nnet-activation.h"
|
||||
#include "nnet/nnet-biasedlinearity.h"
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
void Nnet::Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
|
||||
KALDI_ASSERT(NULL != out);
|
||||
|
||||
if(LayerCount() == 0) {
|
||||
out->Resize(in.NumRows(),in.NumCols(),kUndefined);
|
||||
out->CopyFromMat(in);
|
||||
return;
|
||||
}
|
||||
|
||||
//we need at least L+1 input buffers
|
||||
KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1);
|
||||
|
||||
|
||||
propagate_buf_[0].Resize(in.NumRows(),in.NumCols(),kUndefined);
|
||||
propagate_buf_[0].CopyFromMat(in);
|
||||
|
||||
for(int32 i=0; i<(int32)nnet_.size(); i++) {
|
||||
nnet_[i]->Propagate(propagate_buf_[i],&propagate_buf_[i+1]);
|
||||
}
|
||||
|
||||
Matrix<BaseFloat>& mat = propagate_buf_[nnet_.size()];
|
||||
out->Resize(mat.NumRows(),mat.NumCols(),kUndefined);
|
||||
out->CopyFromMat(mat);
|
||||
}
|
||||
|
||||
|
||||
void Nnet::Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
|
||||
if(LayerCount() == 0) { KALDI_ERR << "Cannot backpropagate on empty network"; }
|
||||
|
||||
//we need at least L+1 input bufers
|
||||
KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1);
|
||||
//we need at least L-1 error bufers
|
||||
KALDI_ASSERT((int32)backpropagate_buf_.size() >= LayerCount()-1);
|
||||
|
||||
//find out when we can stop backprop
|
||||
int32 backprop_stop = -1;
|
||||
if(NULL == out_err) {
|
||||
backprop_stop++;
|
||||
while(1) {
|
||||
if(nnet_[backprop_stop]->IsUpdatable()) {
|
||||
if(0.0 != dynamic_cast<UpdatableComponent*>(nnet_[backprop_stop])->LearnRate()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
backprop_stop++;
|
||||
if(backprop_stop == (int32)nnet_.size()) {
|
||||
KALDI_ERR << "All layers have zero learning rate!";
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
//disable!
|
||||
backprop_stop=-1;
|
||||
|
||||
//////////////////////////////////////
|
||||
// Backpropagation
|
||||
//
|
||||
|
||||
//don't copy the in_err to buffers, use it as is...
|
||||
int32 i = nnet_.size()-1;
|
||||
if(nnet_[i]->IsUpdatable()) {
|
||||
UpdatableComponent* uc = dynamic_cast<UpdatableComponent*>(nnet_[i]);
|
||||
if(uc->LearnRate() > 0.0) {
|
||||
uc->Update(propagate_buf_[i],in_err);
|
||||
}
|
||||
}
|
||||
nnet_.back()->Backpropagate(in_err,&backpropagate_buf_[i-1]);
|
||||
|
||||
//backpropagate by using buffers
|
||||
for(i--; i >= 1; i--) {
|
||||
if(nnet_[i]->IsUpdatable()) {
|
||||
UpdatableComponent* uc = dynamic_cast<UpdatableComponent*>(nnet_[i]);
|
||||
if(uc->LearnRate() > 0.0) {
|
||||
uc->Update(propagate_buf_[i],backpropagate_buf_[i]);
|
||||
}
|
||||
}
|
||||
if(backprop_stop == i) break;
|
||||
nnet_[i]->Backpropagate(backpropagate_buf_[i],&backpropagate_buf_[i-1]);
|
||||
}
|
||||
|
||||
//update first layer
|
||||
if(nnet_[0]->IsUpdatable() && 0 >= backprop_stop) {
|
||||
UpdatableComponent* uc = dynamic_cast<UpdatableComponent*>(nnet_[0]);
|
||||
if(uc->LearnRate() > 0.0) {
|
||||
uc->Update(propagate_buf_[0],backpropagate_buf_[0]);
|
||||
}
|
||||
}
|
||||
//now backpropagate through first layer, but only if asked to (by out_err pointer)
|
||||
if(NULL != out_err) {
|
||||
nnet_[0]->Backpropagate(backpropagate_buf_[0],out_err);
|
||||
}
|
||||
|
||||
//
|
||||
// End of Backpropagation
|
||||
//////////////////////////////////////
|
||||
}
|
||||
|
||||
|
||||
void Nnet::Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
|
||||
KALDI_ASSERT(NULL != out);
|
||||
|
||||
if(LayerCount() == 0) {
|
||||
out->Resize(in.NumRows(),in.NumCols(),kUndefined);
|
||||
out->CopyFromMat(in);
|
||||
return;
|
||||
}
|
||||
|
||||
//we need at least 2 input buffers
|
||||
KALDI_ASSERT(propagate_buf_.size() >= 2);
|
||||
|
||||
//propagate by using exactly 2 auxiliary buffers
|
||||
int32 L = 0;
|
||||
nnet_[L]->Propagate(in,&propagate_buf_[L%2]);
|
||||
for(L++; L<=LayerCount()-2; L++) {
|
||||
nnet_[L]->Propagate(propagate_buf_[(L-1)%2],&propagate_buf_[L%2]);
|
||||
}
|
||||
nnet_[L]->Propagate(propagate_buf_[(L-1)%2],out);
|
||||
}
|
||||
|
||||
|
||||
void Nnet::Read(std::istream& in, bool binary) {
|
||||
//get the network layers from a factory
|
||||
Component *comp;
|
||||
while(NULL != (comp = Component::Read(in,binary,this))) {
|
||||
if(LayerCount() > 0 && nnet_.back()->OutputDim() != comp->InputDim()) {
|
||||
KALDI_ERR << "Dimensionality mismatch!"
|
||||
<< " Previous layer output:" << nnet_.back()->OutputDim()
|
||||
<< " Current layer input:" << comp->InputDim();
|
||||
}
|
||||
nnet_.push_back(comp);
|
||||
}
|
||||
//create empty buffers
|
||||
propagate_buf_.resize(LayerCount()+1);
|
||||
backpropagate_buf_.resize(LayerCount()-1);
|
||||
//reset learn rate
|
||||
learn_rate_ = 0.0;
|
||||
}
|
||||
|
||||
|
||||
void Nnet::LearnRate(BaseFloat lrate, const char* lrate_factors) {
|
||||
//split lrate_factors to a vector
|
||||
std::vector<BaseFloat> lrate_factor_vec;
|
||||
if(NULL != lrate_factors) {
|
||||
char* copy = new char[strlen(lrate_factors)+1];
|
||||
strcpy(copy, lrate_factors);
|
||||
char* tok = NULL;
|
||||
while(NULL != (tok = strtok((tok==NULL?copy:NULL),",:; "))) {
|
||||
lrate_factor_vec.push_back(atof(tok));
|
||||
}
|
||||
delete copy;
|
||||
}
|
||||
|
||||
//count trainable layers
|
||||
int32 updatable = 0;
|
||||
for(int i=0; i<LayerCount(); i++) {
|
||||
if(nnet_[i]->IsUpdatable()) updatable++;
|
||||
}
|
||||
//check number of factors
|
||||
if(lrate_factor_vec.size() > 0 && updatable != (int32)lrate_factor_vec.size()) {
|
||||
KALDI_ERR << "Mismatch between number of trainable layers " << updatable
|
||||
<< " and learn rate factors " << lrate_factor_vec.size();
|
||||
}
|
||||
|
||||
//set learn rates
|
||||
updatable=0;
|
||||
for(int32 i=0; i<LayerCount(); i++) {
|
||||
if(nnet_[i]->IsUpdatable()) {
|
||||
BaseFloat lrate_scaled = lrate;
|
||||
if(lrate_factor_vec.size() > 0) lrate_scaled *= lrate_factor_vec[updatable++];
|
||||
dynamic_cast<UpdatableComponent*>(nnet_[i])->LearnRate(lrate_scaled);
|
||||
}
|
||||
}
|
||||
//set global learn rate
|
||||
learn_rate_ = lrate;
|
||||
}
|
||||
|
||||
|
||||
std::string Nnet::LearnRateString() {
|
||||
std::ostringstream oss;
|
||||
oss << "LEARN_RATE global: " << learn_rate_ << " individual: ";
|
||||
for(int32 i=0; i<LayerCount(); i++) {
|
||||
if(nnet_[i]->IsUpdatable()) {
|
||||
oss << dynamic_cast<UpdatableComponent*>(nnet_[i])->LearnRate() << " ";
|
||||
}
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace
|
|
@ -0,0 +1,223 @@
|
|||
// nnet/nnet-nnet.h
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
|
||||
#ifndef KALDI_NNET_NNET_H
|
||||
#define KALDI_NNET_NNET_H
|
||||
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "util/kaldi-io.h"
|
||||
#include "matrix/matrix-lib.h"
|
||||
#include "nnet/nnet-component.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace kaldi {
|
||||
|
||||
class Nnet {
|
||||
//////////////////////////////////////
|
||||
// Typedefs
|
||||
typedef std::vector<Component*> NnetType;
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Disable copy construction and assignment
|
||||
private:
|
||||
Nnet(Nnet&);
|
||||
Nnet& operator=(Nnet&);
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Constructor & Destructor
|
||||
public:
|
||||
Nnet()
|
||||
{ }
|
||||
|
||||
~Nnet(); //{ } later...
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Public interface
|
||||
public:
|
||||
/// Perform forward pass through the network
|
||||
void Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out);
|
||||
/// Perform backward pass through the network
|
||||
void Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err);
|
||||
/// Perform forward pass through the network, don't keep buffers (use it when not training)
|
||||
void Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out);
|
||||
|
||||
MatrixIndexT InputDim() const; ///< Dimensionality of the input features
|
||||
MatrixIndexT OutputDim() const; ///< Dimensionality of the desired vectors
|
||||
|
||||
MatrixIndexT LayerCount() const { ///< Get number of layers
|
||||
return nnet_.size();
|
||||
}
|
||||
Component* Layer(MatrixIndexT index) { ///< Access to individual layer
|
||||
return nnet_[index];
|
||||
}
|
||||
int IndexOfLayer(const Component& comp) const; ///< Get the position of layer in network
|
||||
|
||||
/// Access to forward pass buffers
|
||||
const std::vector<Matrix<BaseFloat> >& PropagateBuffer() const {
|
||||
return propagate_buf_;
|
||||
}
|
||||
|
||||
/// Access to backward pass buffers
|
||||
const std::vector<Matrix<BaseFloat> >& BackpropagateBuffer() const {
|
||||
return backpropagate_buf_;
|
||||
}
|
||||
|
||||
/// Read the MLP from file (can add layers to exisiting instance of Nnet)
|
||||
void Read(const std::string& file);
|
||||
/// Read the MLP from stream (can add layers to exisiting instance of Nnet)
|
||||
void Read(std::istream& in, bool binary);
|
||||
/// Write MLP to file
|
||||
void Write(const std::string& file, bool binary);
|
||||
/// Write MLP to stream
|
||||
void Write(std::ostream& out, bool binary);
|
||||
|
||||
/// Set the learning rate values to trainable layers,
|
||||
/// factors can disable training of individual layers
|
||||
void LearnRate(BaseFloat lrate, const char* lrate_factors);
|
||||
/// Get the global learning rate value
|
||||
BaseFloat LearnRate() {
|
||||
return learn_rate_;
|
||||
}
|
||||
/// Get the string with real learning rate values
|
||||
std::string LearnRateString();
|
||||
|
||||
void Momentum(BaseFloat mmt);
|
||||
void L2Penalty(BaseFloat l2);
|
||||
void L1Penalty(BaseFloat l1);
|
||||
|
||||
//////////////////////////////////////////////////////////////
|
||||
// Private interface
|
||||
private:
|
||||
/// Creates a component by reading from stream, return NULL if no more components
|
||||
static Component* ComponentFactory(std::istream& in, bool binary, Nnet* nnet);
|
||||
/// Dumps individual component to stream
|
||||
static void ComponentDumper(std::ostream& out, bool binary, const Component& comp);
|
||||
|
||||
private:
|
||||
NnetType nnet_; ///< vector of all Component*, represents layers
|
||||
|
||||
std::vector<Matrix<BaseFloat> > propagate_buf_; ///< buffers for forward pass
|
||||
std::vector<Matrix<BaseFloat> > backpropagate_buf_; ///< buffers for backward pass
|
||||
|
||||
BaseFloat learn_rate_; ///< global learning rate
|
||||
};
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// INLINE FUNCTIONS
|
||||
// Nnet::
|
||||
inline Nnet::~Nnet() {
|
||||
//delete all the components
|
||||
NnetType::iterator it;
|
||||
for(it=nnet_.begin(); it!=nnet_.end(); ++it) {
|
||||
delete *it;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline MatrixIndexT Nnet::InputDim() const {
|
||||
if(LayerCount() > 0) {
|
||||
return nnet_.front()->InputDim();
|
||||
} else {
|
||||
KALDI_ERR << "No layers in MLP";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline MatrixIndexT Nnet::OutputDim() const {
|
||||
if(LayerCount() > 0) {
|
||||
return nnet_.back()->OutputDim();
|
||||
} else {
|
||||
KALDI_ERR << "No layers in MLP";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline int32 Nnet::IndexOfLayer(const Component& comp) const {
|
||||
for(int32 i=0; i<LayerCount(); i++) {
|
||||
if(&comp == nnet_[i]) return i;
|
||||
}
|
||||
KALDI_ERR << "Component:" << &comp
|
||||
<< " type:" << comp.GetType()
|
||||
<< " not found in the MLP";
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
inline void Nnet::Read(const std::string& file) {
|
||||
bool binary;
|
||||
Input in(file,&binary);
|
||||
Read(in.Stream(),binary);
|
||||
in.Close();
|
||||
}
|
||||
|
||||
|
||||
inline void Nnet::Write(const std::string& file, bool binary) {
|
||||
Output out(file, binary, true);
|
||||
Write(out.Stream(),binary);
|
||||
out.Close();
|
||||
}
|
||||
|
||||
|
||||
inline void Nnet::Write(std::ostream& out, bool binary) {
|
||||
for(int32 i=0; i<LayerCount(); i++) {
|
||||
nnet_[i]->Write(out,binary);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void Nnet::Momentum(BaseFloat mmt) {
|
||||
for(int32 i=0; i<LayerCount(); i++) {
|
||||
if(nnet_[i]->IsUpdatable()) {
|
||||
dynamic_cast<UpdatableComponent*>(nnet_[i])->Momentum(mmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void Nnet::L2Penalty(BaseFloat l2) {
|
||||
for(int32 i=0; i<LayerCount(); i++) {
|
||||
if(nnet_[i]->IsUpdatable()) {
|
||||
dynamic_cast<UpdatableComponent*>(nnet_[i])->L2Penalty(l2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void Nnet::L1Penalty(BaseFloat l1) {
|
||||
for(int32 i=0; i<LayerCount(); i++) {
|
||||
if(nnet_[i]->IsUpdatable()) {
|
||||
dynamic_cast<UpdatableComponent*>(nnet_[i])->L1Penalty(l1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
} //namespace kaldi
|
||||
|
||||
#endif
|
||||
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
// nnet/nnet-test.cc
|
||||
|
||||
// Copyright 2010 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "base/kaldi-common.h"
|
||||
#include "nnet/nnet-component.h"
|
||||
#include "nnet/nnet-nnet.h"
|
||||
|
||||
using namespace kaldi;
|
||||
|
||||
static void UnitTestSomething() {
|
||||
KALDI_ERR << "Unimeplemented";
|
||||
}
|
||||
|
||||
|
||||
static void UnitTestNnet() {
|
||||
try {
|
||||
UnitTestSomething();
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int main() {
|
||||
UnitTestNnet();
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
|
||||
all:
|
||||
EXTRA_CXXFLAGS = -Wno-sign-compare
|
||||
include ../kaldi.mk
|
||||
|
||||
BINFILES = nnet-train-xent-hardlab-perutt
|
||||
|
||||
OBJFILES =
|
||||
|
||||
all: $(BINFILES)
|
||||
|
||||
|
||||
TESTFILES =
|
||||
|
||||
|
||||
$(BINFILES): ../nnet/kaldi-nnet.a ../matrix/kaldi-matrix.a ../util/kaldi-util.a ../base/kaldi-base.a
|
||||
|
||||
|
||||
|
||||
# Rule below would expand to, e.g.:
|
||||
# ../base/kaldi-base.a:
|
||||
# make -c ../base kaldi-base.a
|
||||
# -c option to make is same as changing directory.
|
||||
%.a:
|
||||
$(MAKE) -C ${@D} ${@F}
|
||||
|
||||
clean:
|
||||
rm *.o *.a $(TESTFILES) $(BINFILES)
|
||||
|
||||
test: $(TESTFILES)
|
||||
for x in $(TESTFILES); do ./$$x >&/dev/null || (echo "***test $$x failed***"; exit 1); done
|
||||
echo Tests succeeded
|
||||
|
||||
.valgrind: $(TESTFILES)
|
||||
|
||||
|
||||
depend:
|
||||
-$(CXX) -M $(CXXFLAGS) *.cc > .depend.mk
|
||||
|
||||
# removing automatic making of "depend" as it's quite slow.
|
||||
#.depend.mk: depend
|
||||
|
||||
-include .depend.mk
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
// nnet/nnet-train-xent-hardlab-perutt.cc
|
||||
|
||||
// Copyright 2011 Karel Vesely
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
||||
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
||||
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
||||
// See the Apache 2 License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "nnet/nnet-nnet.h"
|
||||
#include "nnet/nnet-loss.h"
|
||||
#include "base/kaldi-common.h"
|
||||
#include "util/common-utils.h"
|
||||
#include "util/timer.h"
|
||||
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
using namespace kaldi;
|
||||
try {
|
||||
const char *usage =
|
||||
"Perform iteration of Neural Network training by stochastic gradient descent.\n"
|
||||
"Usage: nnet-train-xent-hardlab-perutt [options] <model-in> <feature-rspecifier> <alignments-rspecifier> [<model-out>]\n"
|
||||
"e.g.: \n"
|
||||
" nnet-train-xent-hardlab-perutt nnet.init scp:train.scp ark:train.ali nnet.iter1\n";
|
||||
|
||||
ParseOptions po(usage);
|
||||
bool binary = false,
|
||||
crossvalidate = false;
|
||||
po.Register("binary", &binary, "Write output in binary mode");
|
||||
po.Register("cross-validate", &crossvalidate, "Perform cross-validation (don't backpropagate)");
|
||||
|
||||
BaseFloat learn_rate = 0.008,
|
||||
momentum = 0.0,
|
||||
l2_penalty = 0.0,
|
||||
l1_penalty = 0.0;
|
||||
|
||||
po.Register("learn-rate", &learn_rate, "Learning rate");
|
||||
po.Register("momentum", &momentum, "Momentum");
|
||||
po.Register("l2-penalty", &l2_penalty, "L2 penalty (weight decay)");
|
||||
po.Register("l1-penalty", &l1_penalty, "L1 penalty (promote sparsity)");
|
||||
|
||||
std::string feature_transform;
|
||||
po.Register("feature-transform", &feature_transform, "Feature transform Neural Network");
|
||||
|
||||
po.Read(argc, argv);
|
||||
|
||||
if (po.NumArgs() != 4-(crossvalidate?1:0)) {
|
||||
po.PrintUsage();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string model_filename = po.GetArg(1),
|
||||
feature_rspecifier = po.GetArg(2),
|
||||
alignments_rspecifier = po.GetArg(3);
|
||||
|
||||
std::string target_model_filename;
|
||||
if(!crossvalidate) {
|
||||
target_model_filename = po.GetArg(4);
|
||||
}
|
||||
|
||||
|
||||
using namespace kaldi;
|
||||
typedef kaldi::int32 int32;
|
||||
|
||||
|
||||
Nnet nnet_transf;
|
||||
if(feature_transform != "") {
|
||||
nnet_transf.Read(feature_transform);
|
||||
}
|
||||
|
||||
Nnet nnet;
|
||||
nnet.Read(model_filename);
|
||||
|
||||
nnet.LearnRate(learn_rate,NULL);
|
||||
nnet.Momentum(momentum);
|
||||
nnet.L2Penalty(l2_penalty);
|
||||
nnet.L1Penalty(l1_penalty);
|
||||
|
||||
kaldi::int64 tot_t = 0;
|
||||
|
||||
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
|
||||
RandomAccessInt32VectorReader alignments_reader(alignments_rspecifier);
|
||||
|
||||
Xent xent;
|
||||
|
||||
Matrix<BaseFloat> feats_transf, nnet_out, glob_err;
|
||||
|
||||
Timer tim;
|
||||
KALDI_LOG << (crossvalidate?"CROSSVALIDATE":"TRAINING") << " STARTED";
|
||||
|
||||
int32 num_done = 0, num_no_alignment = 0, num_other_error = 0;
|
||||
for (; !feature_reader.Done(); feature_reader.Next()) {
|
||||
std::string key = feature_reader.Key();
|
||||
if (!alignments_reader.HasKey(key)) {
|
||||
num_no_alignment++;
|
||||
} else {
|
||||
const Matrix<BaseFloat> &mat = feature_reader.Value();
|
||||
const std::vector<int32> &alignment = alignments_reader.Value(key);
|
||||
|
||||
//std::cout << mat;
|
||||
|
||||
if ((int32)alignment.size() != mat.NumRows()) {
|
||||
KALDI_WARN << "Alignment has wrong size "<< (alignment.size()) << " vs. "<< (mat.NumRows());
|
||||
num_other_error++;
|
||||
continue;
|
||||
}
|
||||
|
||||
if(num_done % 10000 == 0) std::cout << num_done << ", " << std::flush;
|
||||
num_done++;
|
||||
|
||||
nnet_transf.Feedforward(mat,&feats_transf);
|
||||
nnet.Propagate(feats_transf,&nnet_out);
|
||||
//std::cout << "\nNETOUT" << nnet_out;
|
||||
xent.Eval(nnet_out,alignment,&glob_err);
|
||||
//std::cout << "\nALIGN" << alignment[0] << " "<< alignment[1]<< " "<< alignment[2];
|
||||
//std::cout << "\nGLOBERR" << glob_err;
|
||||
if(!crossvalidate) {
|
||||
nnet.Backpropagate(glob_err,NULL);
|
||||
}
|
||||
|
||||
tot_t += mat.NumRows();
|
||||
}
|
||||
}
|
||||
|
||||
if(!crossvalidate) {
|
||||
nnet.Write(target_model_filename,binary);
|
||||
}
|
||||
|
||||
std::cout << "\n" << std::flush;
|
||||
|
||||
KALDI_LOG << (crossvalidate?"CROSSVALIDATE":"TRAINING") << " FINISHED "
|
||||
<< tim.Elapsed() << "s, fps" << tot_t/tim.Elapsed();
|
||||
|
||||
KALDI_LOG << "Done " << num_done << " files, " << num_no_alignment
|
||||
<< " with no alignments, " << num_other_error
|
||||
<< " with other errors.";
|
||||
|
||||
KALDI_LOG << xent.Report();
|
||||
|
||||
|
||||
return 0;
|
||||
} catch(const std::exception& e) {
|
||||
std::cerr << e.what();
|
||||
return -1;
|
||||
}
|
||||
}
|
Загрузка…
Ссылка в новой задаче