diff --git a/egs/rm/README.txt b/egs/rm/README.txt index 8b812cca1..371662156 100644 --- a/egs/rm/README.txt +++ b/egs/rm/README.txt @@ -5,4 +5,7 @@ scripts for a sequence of experiments. s1: This setup is experiments with GMM-based systems with various Maximum Likelihood techniques including global and speaker-specific transforms. - See a parallel setup in ../wsj/s1 + See a parallel setup in ../wsj/s1 + + s2: This setup is experiment with hybrid MLP system trained by + stochastic gradient descent. diff --git a/egs/rm/s2/NOTES b/egs/rm/s2/NOTES new file mode 100644 index 000000000..e69de29bb diff --git a/egs/rm/s2/RESULTS b/egs/rm/s2/RESULTS new file mode 100644 index 000000000..e69de29bb diff --git a/egs/rm/s2/conf/mfcc.conf b/egs/rm/s2/conf/mfcc.conf new file mode 100644 index 000000000..736150909 --- /dev/null +++ b/egs/rm/s2/conf/mfcc.conf @@ -0,0 +1 @@ +--use-energy=false # only non-default option. diff --git a/egs/rm/s2/conf/topo.proto b/egs/rm/s2/conf/topo.proto new file mode 100644 index 000000000..14a6da739 --- /dev/null +++ b/egs/rm/s2/conf/topo.proto @@ -0,0 +1,22 @@ + + + +NONSILENCEPHONES + + 0 0 0 0.75 1 0.25 + 1 1 1 0.75 2 0.25 + 2 2 2 0.75 3 0.25 + 3 + + + +SILENCEPHONES + + 0 0 0 0.25 1 0.25 2 0.25 3 0.25 + 1 1 1 0.25 2 0.25 3 0.25 4 0.25 + 2 2 1 0.25 2 0.25 3 0.25 4 0.25 + 3 3 1 0.25 2 0.25 3 0.25 4 0.25 + 4 4 4 0.25 5 0.75 + 5 + + diff --git a/egs/rm/s2/data_prep/make_trans.pl b/egs/rm/s2/data_prep/make_trans.pl new file mode 100755 index 000000000..2ad181dfd --- /dev/null +++ b/egs/rm/s2/data_prep/make_trans.pl @@ -0,0 +1,69 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# usage: make_trans.pl prefix in.flist input.snr out.txt out.scp
+
+# prefix is first letters of the database "key" (rest are numeric)
+
+# in.flist is just a list of filenames, probably of .sph files.
+# input.snr is an snr format file from the RM dataset.
+# out.txt is the output transcriptions in format "key word1 word\n"
+# out.scp is the output scp file, which is as in.flist but has the
+# database-key first on each line.
+
+# Reads the SNOR transcripts from input.snr, e.g. $rootdir/rm1_audio1/rm1/doc/al_sents.snr,
+# and the list of audio files from in.flist, e.g. train_sph.flist.
+# Writes the transcriptions to out.txt and the keyed file list to out.scp.
+
+if(@ARGV != 5) {
+  die "usage: make_trans.pl prefix in.flist input.snr out.txt out.scp\n";
+}
+($prefix, $in_flist, $input_snr, $out_txt, $out_scp) = @ARGV;
+
+open(F, "<$input_snr") || die "Opening SNOR file $input_snr";
+
+while(<F>) { # each transcript line looks like "WORD1 WORD2 ... (UTTNAME)"
+  if(m/^;/) { next; } # lines starting with ';' are skipped
+  m/(.+) \((.+)\)/ || die "bad line $_";
+  $T{$2} = $1; # utterance name -> transcript text
+}
+
+close(F);
+open(G, "<$in_flist") || die "Opening file list $in_flist";
+
+open(O, ">$out_txt") || die "Open output transcription file $out_txt";
+
+open(P, ">$out_scp") || die "Open output scp file $out_scp";
+
+while(<G>) { # each line is a path ending in .../<speaker>/<utterance>.sph
+  $_ =~ m:/(\w+)/(\w+)\.sph\s+$:i || die "bad scp line $_";
+  $spkname = $1;
+  $uttname = $2;
+  $uttname =~ tr/a-z/A-Z/; # keys of %T are matched in upper case
+  defined $T{$uttname} || die "no trans for sent $uttname";
+  $spkname =~ s/_//g; # remove underscore from spk name to make key nicer.
+  $key = $prefix . "_" . $spkname . "_" . $uttname;
+  $key =~ tr/A-Z/a-z/; # Make it all lower case.
+  # to make the numerical and string-sorted orders the same.
+  print O "$key $T{$uttname}\n";
+  print P "$key $_"; # $_ still carries its own newline
+  $n++; # utterance counter (never reported)
+}
+close(O) || die "Closing output.";
+close(P) || die "Closing output.";
+
+
diff --git a/egs/rm/s2/data_prep/run.sh b/egs/rm/s2/data_prep/run.sh
new file mode 100755
index 000000000..38e6b43d1
--- /dev/null
+++ b/egs/rm/s2/data_prep/run.sh
@@ -0,0 +1,92 @@
+# This script should be run from the directory where it is located (i.e. data_prep)
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# The input is the 3 CDs from the LDC distribution of Resource Management.
+# The script's argument is a directory which has three subdirectories:
+# rm1_audio1 rm1_audio2 rm2_audio
+
+if [ $# != 1 ]; then
+  echo "Usage: ./run.sh /path/to/RM"
+  exit 1;
+fi
+
+RMROOT=$1
+if [ ! -d $RMROOT/rm1_audio1 -o ! -d $RMROOT/rm1_audio2 ]; then
+  echo "Error: run.sh requires a directory argument that contains rm1_audio1 and rm1_audio2"
+  exit 1;
+fi
+
+if [ ! 
-d $RMROOT/rm2_audio ]; then + echo "**Warning: $RMROOT/rm2_audio does not exist; won't create spk2gender.map file correctly***" + sleep 1 +fi + +( + find $RMROOT/rm1_audio1/rm1/ind_trn -iname '*.sph'; + find $RMROOT/rm1_audio2/2_4_2/rm1/ind/dev_aug -iname '*.sph'; +) | perl -ane ' m:/sa\d.sph:i || m:/sb\d\d.sph:i || print; ' > train_sph.flist + + + +# make_trans.pl also creates the utterance id's and the kaldi-format scp file. +./make_trans.pl trn train_sph.flist $RMROOT/rm1_audio1/rm1/doc/al_sents.snr train_trans.txt train_sph.scp +mv train_trans.txt tmp; sort -k 1 tmp > train_trans.txt +mv train_sph.scp tmp; sort -k 1 tmp > train_sph.scp + +sph2pipe=`cd ../../../..; echo $PWD/tools/sph2pipe_v2.5/sph2pipe` +if [ ! -f $sph2pipe ]; then + echo "Could not find the sph2pipe program at $sph2pipe"; + exit 1; +fi +awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < train_sph.scp > train_wav.scp + +cat train_wav.scp | perl -ane 'm/^(\w+_(\w+)\w_\w+) / || die; print "$1 $2\n"' > train.utt2spk +cat train.utt2spk | sort -k 2 | ../scripts/utt2spk_to_spk2utt.pl > train.spk2utt + + +for ntest in 1_mar87 2_oct87 4_feb89 5_oct89 6_feb91 7_sep92; do + n=`echo $ntest | cut -d_ -f 1` + test=`echo $ntest | cut -d_ -f 2` + root=$RMROOT/rm1_audio2/2_4_2 + for x in `grep -v ';' $root/rm1/doc/tests/$ntest/${n}_indtst.ndx`; do + echo "$root/$x "; + done > test_${test}_sph.flist +done + +# make_trans.pl also creates the utterance id's and the kaldi-format scp file. 
+for test in mar87 oct87 feb89 oct89 feb91 sep92; do + ./make_trans.pl ${test} test_${test}_sph.flist $RMROOT/rm1_audio1/rm1/doc/al_sents.snr test_${test}_trans.txt test_${test}_sph.scp + mv test_${test}_trans.txt tmp; sort -k 1 tmp > test_${test}_trans.txt + mv test_${test}_sph.scp tmp; sort -k 1 tmp > test_${test}_sph.scp + + awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < test_${test}_sph.scp > test_${test}_wav.scp + + cat test_${test}_wav.scp | perl -ane 'm/^(\w+_(\w+)\w_\w+) / || die; print "$1 $2\n"' > test_${test}.utt2spk + cat test_${test}.utt2spk | sort -k 2 | ../scripts/utt2spk_to_spk2utt.pl > test_${test}.spk2utt +done + +cat $RMROOT/rm1_audio2/2_5_1/rm1/doc/al_spkrs.txt \ + $RMROOT/rm2_audio/3-1.2/rm2/doc/al_spkrs.txt | \ + perl -ane 'tr/A-Z/a-z/;print;' | grep -v ';' | \ + awk '{print $1, $2}' > spk2gender.map + +../scripts/make_rm_lm.pl $RMROOT/rm1_audio1/rm1/doc/wp_gram.txt > G.txt + +# Getting lexicon +../scripts/make_rm_dict.pl $RMROOT/rm1_audio2/2_4_2/score/src/rdev/pcdsril.txt > lexicon.txt + +echo Succeeded. diff --git a/egs/rm/s2/data_prep/sph2wav.sh b/egs/rm/s2/data_prep/sph2wav.sh new file mode 100644 index 000000000..841a357eb --- /dev/null +++ b/egs/rm/s2/data_prep/sph2wav.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +fake=false +if [ "$1" == "--fake" ]; then + fake=true + shift +fi + +sphdir=$1 # e.g. /mnt/matylda2/data/RM +wavdir=$2 # e.g. /mnt/matylda6/jhu09/qpovey/kaldi_rm_wav +flistin=$3 # e.g. train_sph.flist, contains sph files in sphdir +flistout=$4 # e.g. train_wav.flist, contains wav files in wavdir + + +if [ $fake == false ]; then + for x in `cat $flistin`; do + y=`echo $x | sed s:$sphdir:$wavdir: | sed s:.sph:.wav:`; + mkdir -p `dirname $y` + ../../tools/sph2pipe_v2.5/sph2pipe -f wav $x $y || exit 1; + done +fi + +cat $flistin | sed s:$sphdir:$wavdir: | sed s:.sph:.wav: > $flistout || exit 1; + diff --git a/egs/rm/s2/path.sh b/egs/rm/s2/path.sh new file mode 100755 index 000000000..b2ce66cf5 --- /dev/null +++ b/egs/rm/s2/path.sh @@ -0,0 +1 @@ +export PATH=$PATH:../../../src/bin:../../../tools/openfst/bin:../../../src/fstbin/:../../../src/gmmbin/:../../../src/featbin/:../../../src/fgmmbin:../../../src/sgmmbin:../../../src/nnetbin diff --git a/egs/rm/s2/run.sh b/egs/rm/s2/run.sh new file mode 100644 index 000000000..471f5d838 --- /dev/null +++ b/egs/rm/s2/run.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +#exit 1 # Don't run this... it's to be run line by line from the shell. + +# This script file cannot be run as-is; some paths in it need to be changed +# before you can run it. +# Search for /path/to. 
+# It is recommended that you do not invoke this file from the shell, but +# run the paths one by one, by hand. + +# the step in data_prep/ will need to be modified for your system. + +# First step is to do data preparation: +# This just creates some text files, it is fast. +# If not on the BUT system, you would have to change run.sh to reflect +# your own paths. +# + +#Example arguments to run.sh: /mnt/matylda2/data/RM, /ais/gobi2/speech/RM, /cygdrive/e/data/RM +# RM is a directory with subdirectories rm1_audio1, rm1_audio2, rm2_audio +cd data_prep +#*** You have to change the pathname below.*** +./run.sh /path/to/RM +cd .. + +mkdir -p data +( cd data; cp ../data_prep/{train,test*}.{spk2utt,utt2spk} . ; cp ../data_prep/spk2gender.map . ) + +# This next step converts the lexicon, grammar, etc., into FST format. +steps/prepare_graphs.sh + + +# Next, make sure that "exp/" is someplace you can write a significant amount of +# data to (e.g. make it a link to a file on some reasonably large file system). +# If it doesn't exist, the scripts below will make the directory "exp". + +# mfcc should be set to some place to put training mfcc's +# where you have space. 
+#e.g.: mfccdir=/mnt/matylda6/jhu09/qpovey/kaldi_rm_mfccb +mfccdir=/path/to/mfccdir +steps/make_mfcc_train.sh $mfccdir +steps/make_mfcc_test.sh $mfccdir + + + +# first, we will train monophone GMM system to get training labels +steps/train_mono.sh +steps/decode_mono.sh & + + + +# Now we train the MLP, +# it will have CMVN normalized MFCCs as input and phoneme-state posteriors as output +steps/train_nnet.sh +#steps/decode_nnet.sh #TODO + diff --git a/egs/rm/s2/scripts/add_disambig.pl b/egs/rm/s2/scripts/add_disambig.pl new file mode 100755 index 000000000..a37af62b3 --- /dev/null +++ b/egs/rm/s2/scripts/add_disambig.pl @@ -0,0 +1,58 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Adds some specified number of disambig symbols to a symbol table. +# Adds these as #1, #2, etc. +# If the --include-zero option is specified, includes an extra one +# #0. 
+if(!(@ARGV == 2 || (@ARGV ==3 && $ARGV[0] eq "--include-zero"))) {
+  die "Usage: add_disambig.pl [--include-zero] symtab.txt num_extra > symtab_out.txt ";
+}
+
+if(@ARGV == 3) {
+  $include_zero = 1;
+  $ARGV[0] eq "--include-zero" || die "Bad option/first argument $ARGV[0]";
+  shift @ARGV;
+} else {
+  $include_zero = 0;
+}
+
+$input = $ARGV[0];
+$nsyms = $ARGV[1];
+
+open(F, "<$input") || die "Opening file $input";
+
+while(<F>) { # copy the existing "symbol id" table through unchanged
+  @A = split(" ", $_);
+  @A == 2 || die "Bad line $_";
+  $lastsym = $A[1]; # id on the last line read; new ids are numbered after it
+  print;
+}
+
+if(!defined($lastsym)){
+  die "Empty symbol file?";
+}
+
+if($include_zero) { # "#0" takes the next free id, shifting #1..#N up by one
+  $lastsym++;
+  print "#0 $lastsym\n";
+}
+
+for($n = 1; $n <= $nsyms; $n++) { # append "#1" .. "#nsyms" with consecutive ids
+  $y = $n + $lastsym;
+  print "#$n $y\n";
+}
diff --git a/egs/rm/s2/scripts/add_lex_disambig.pl b/egs/rm/s2/scripts/add_lex_disambig.pl
new file mode 100755
index 000000000..9f9054e17
--- /dev/null
+++ b/egs/rm/s2/scripts/add_lex_disambig.pl
@@ -0,0 +1,101 @@
+#!/usr/bin/perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Adds disambiguation symbols to a lexicon.
+# Outputs still in the normal lexicon format.
+# Disambig syms are numbered #1, #2, #3, etc. (#0
+# reserved for symbol in grammar).
+# Outputs the number of disambig syms to the standard output.
+
+if(@ARGV != 2) {
+  die "Usage: add_lex_disambig.pl lexicon.txt lexicon_disambig.txt "
+}
+
+
+$lexfn = shift @ARGV;
+$lexoutfn = shift @ARGV;
+
+open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
+
+# (1) Read in the lexicon.
+@L = ( );
+while(<L>) {
+  @A = split(" ", $_);
+  push @L, join(" ", @A); # store each entry with whitespace normalized to single spaces
+}
+
+# (2) Work out the count of each phone-sequence in the
+# lexicon.
+
+foreach $l (@L) {
+  @A = split(" ", $l);
+  shift @A; # Remove word.
+  $count{join(" ",@A)}++;
+}
+
+# (3) For each left sub-sequence of each phone-sequence, note down
+# that exists (for identifying prefixes of longer strings).
+
+foreach $l (@L) {
+  @A = split(" ", $l);
+  shift @A; # Remove word.
+  while(@A > 0) {
+    pop @A; # Remove last phone
+    $issubseq{join(" ",@A)} = 1;
+  }
+}
+
+# (4) For each entry in the lexicon:
+# if the phone sequence is unique and is not a
+# prefix of another word, no disambig symbol.
+# Else output #1, or #2, #3, ... if the same phone-seq
+# has already been assigned a disambig symbol.
+
+
+open(O, ">$lexoutfn") || die "Opening lexicon file $lexoutfn for writing.\n";
+
+$max_disambig = 0;
+foreach $l (@L) {
+  @A = split(" ", $l);
+  $word = shift @A;
+  $phnseq = join(" ",@A);
+  if(!defined $issubseq{$phnseq}
+     && $count{$phnseq}==1) {
+    ; # Do nothing.
+  } else {
+    if($phnseq eq "") { # need disambig symbols for the empty string
+      # that are not used anywhere else.
+      $max_disambig++;
+      $reserved{$max_disambig} = 1;
+      $phnseq = "#$max_disambig";
+    } else {
+      $curnumber = $disambig_of{$phnseq};
+      if(!defined($curnumber)) { $curnumber = 0; } # first time this phone-seq needs a symbol
+      $curnumber++; # now 1 or 2, ...
+      while(defined $reserved{$curnumber} ) { $curnumber++; } # skip over reserved symbols
+      if($curnumber > $max_disambig) {
+        $max_disambig = $curnumber;
+      }
+      $disambig_of{$phnseq} = $curnumber;
+      $phnseq = $phnseq . " #" . $curnumber;
+    }
+  }
+  print O "$word\t$phnseq\n";
+}
+
+print $max_disambig . "\n";
+
diff --git a/egs/rm/s2/scripts/filter_scp.pl b/egs/rm/s2/scripts/filter_scp.pl
new file mode 100755
index 000000000..ac40838b7
--- /dev/null
+++ b/egs/rm/s2/scripts/filter_scp.pl
@@ -0,0 +1,40 @@
+#!/usr/bin/perl -w
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids and filters an scp
+# file (or any file whose first field is an utterance id), printing
+# out only those lines whose first field is in id_list.
+
+if(@ARGV < 1 || @ARGV > 2) {
+  die "Usage: filter_scp.pl id_list [in.scp] > out.scp ";
+}
+
+$idlist = shift @ARGV;
+open(F, "<$idlist") || die "Could not open id-list file $idlist";
+while(<F>) { # remember the first field of every id-list line
+  @A = split;
+  @A>=1 || die "Invalid id-list file line $_";
+  $seen{$A[0]} = 1;
+}
+
+while(<>) { # remaining argument (in.scp) if given, otherwise standard input
+  @A = split;
+  @A > 0 || die "Invalid scp file line $_";
+  if($seen{$A[0]}) { # keep only lines whose first field was in the id list
+    print $_;
+  }
+}
diff --git a/egs/rm/s2/scripts/gen_mlp_init.py b/egs/rm/s2/scripts/gen_mlp_init.py
new file mode 100755
index 000000000..3a4fe2481
--- /dev/null
+++ b/egs/rm/s2/scripts/gen_mlp_init.py
@@ -0,0 +1,75 @@
+#!/usr/bin/python -u
+
+# ./gen_mlp_init.py
+# script generating NN initialization for training with TNet
+#
+# author: Karel Vesely
+#
+
+import math, random
+import sys
+
+
+from optparse import OptionParser
+
+parser = OptionParser()
+parser.add_option('--dim', dest='dim', help='d1:d2:d3 layer dimensions in the network')
+parser.add_option('--gauss', dest='gauss', help='use gaussian noise for weights', action='store_true', default=False)
+parser.add_option('--negbias', dest='negbias', help='use uniform [-4.1,-3.9] for bias (defaultall 0.0)', action='store_true', default=False)
+parser.add_option('--inputscale', dest='inputscale', help='scale the weights by 3/sqrt(Ninputs)', action='store_true', default=False)
+parser.add_option('--linBNdim', dest='linBNdim', help='dim of linear bottleneck (sigmoids will be omitted, bias will be zero)',default=0)
+(options, args) = parser.parse_args()
+
+if(options.dim == None):
+    parser.print_help()
+    sys.exit(1)
+
+
+dimStrL = options.dim.split(':')
+
+dimL = []
+for i in range(len(dimStrL)):
+    dimL.append(int(dimStrL[i]))
+
+
+#print dimL,'linBN',options.linBNdim
+
+for layer in range(len(dimL)-1):
+    print '', dimL[layer+1], dimL[layer]
+    #weight matrix
+    print '['
+    for row in range(dimL[layer+1]):
+        for col in range(dimL[layer]):
+            if(options.gauss):
+                if(options.inputscale):
+                    print 3/math.sqrt(dimL[layer])*random.gauss(0.0,1.0),
+                else:
+ print 0.1*random.gauss(0.0,1.0), + else: + if(options.inputscale): + print (random.random()-0.5)*2*3/math.sqrt(dimL[layer]), + else: + print random.random()/5.0-0.1, + print #newline for each row + print ']' + #bias vector + print '[', + for idx in range(dimL[layer+1]): + if(int(options.linBNdim) == dimL[layer+1]): + print '0.0', + elif(options.negbias): + print random.random()/5.0-4.1, + else: + print '0.0', + print ']' + + if(int(options.linBNdim) != dimL[layer+1]): + if(layer == len(dimL)-2): + print '', dimL[layer+1], dimL[layer+1] + else: + print '', dimL[layer+1], dimL[layer+1] + + + + + diff --git a/egs/rm/s2/scripts/int2sym.pl b/egs/rm/s2/scripts/int2sym.pl new file mode 100755 index 000000000..c91802bcc --- /dev/null +++ b/egs/rm/s2/scripts/int2sym.pl @@ -0,0 +1,69 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+
+# Maps integer ids back to symbols using a "symbol id" table; reads stdin or the named input.
+$ignore_noninteger = 0;
+$ignore_first_field = 0;
+for($x = 0; $x < 2; $x++) { # two passes so the options may be given in either order
+  if($ARGV[0] eq "--ignore-noninteger") { $ignore_noninteger = 1; shift @ARGV; }
+  if($ARGV[0] eq "--ignore-first-field") { $ignore_first_field = 1; shift @ARGV; }
+}
+
+$symtab = shift @ARGV;
+if(!defined $symtab) {
+  die "Usage: int2sym.pl [--ignore-noninteger] [--ignore-first-field] symtab [input transcriptions] > output transcriptions\n";
+}
+open(F, "<$symtab") || die "Error opening symbol table file $symtab";
+while(<F>) { # each line is "symbol integer"
+  @A = split(" ", $_);
+  @A == 2 || die "bad line in symbol table file: $_";
+  $int2sym{$A[1]} = $A[0];
+}
+
+$error = 0;
+while(<>) {
+  @A = split(" ", $_);
+  if(@A == 0) {
+    die "Empty line in transcriptions input.";
+  }
+  if($ignore_first_field) { # pass the leading key (e.g. utterance id) through untouched
+    $key = shift @A;
+    print $key . " ";
+  }
+  foreach $a (@A) {
+    if($a !~ m:^\d+$:) { # not all digits..
+      if($ignore_noninteger) { # copy non-integer tokens through as they are
+        print $a . " ";
+        next;
+      } else {
+        if($a eq $A[0]) {
+          die "int2sym.pl: found noninteger token $a (try --ignore-first-field)\n";
+        } else {
+          die "int2sym.pl: found noninteger token $a (try --ignore-noninteger if valid input)\n";
+        }
+      }
+    }
+    $s = $int2sym{$a};
+    if(!defined ($s)) {
+      die "int2sym.pl: integer $a not in symbol table $symtab.";
+    }
+    print $s . " ";
+  }
+  print "\n";
+}
+
+
+
diff --git a/egs/rm/s2/scripts/is_sorted.sh b/egs/rm/s2/scripts/is_sorted.sh
new file mode 100755
index 000000000..ac6ae42e7
--- /dev/null
+++ b/egs/rm/s2/scripts/is_sorted.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Usage: is_sorted.sh [script-file] +# This script returns 0 (success) if the script file argument [or standard input] +# is sorted and 1 otherwise. + +export LC_ALL=C + +if [ $# == 0 ]; then + scp=- +fi +if [ $# == 1 ]; then + scp=$1 +fi +if [ $# -gt 1 -o "$1" == "--help" -o "$1" == "-h" ]; then + echo "Usage: is_sorted.sh [script-file]" + exit 1 +fi + +cat $scp > /tmp/tmp1.$$ +sort /tmp/tmp1.$$ > /tmp/tmp2.$$ +cmp /tmp/tmp1.$$ /tmp/tmp2.$$ >/dev/null +ret=$? +rm /tmp/tmp1.$$ /tmp/tmp2.$$ +if [ $ret == 0 ]; then + exit 0; +else + echo "is_sorted.sh: script file $scp is not sorted"; + exit 1; +fi diff --git a/egs/rm/s2/scripts/make_lexicon_fst.pl b/egs/rm/s2/scripts/make_lexicon_fst.pl new file mode 100755 index 000000000..5293f4023 --- /dev/null +++ b/egs/rm/s2/scripts/make_lexicon_fst.pl @@ -0,0 +1,112 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# makes lexicon FST (no pron-probs involved). 
+
+if(@ARGV != 1 && @ARGV != 3) {
+  die "Usage: make_lexicon_fst.pl lexicon.txt [silprob silphone] > lexiconfst.txt"
+}
+
+$lexfn = shift @ARGV;
+if(@ARGV == 0) {
+  $silprob = 0.0;
+} else {
+  ($silprob,$silphone) = @ARGV;
+}
+if($silprob != 0.0) { # arc costs are negative log probabilities
+  $silprob < 1.0 || die "Sil prob cannot be >= 1.0";
+  $silcost = -log($silprob);
+  $nosilcost = -log(1.0 - $silprob);
+}
+
+
+open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
+
+
+
+if( $silprob == 0.0 ) { # No optional silences: just have one (loop+final) state which is numbered zero.
+  $loopstate = 0;
+  $nextstate = 1; # next unallocated state.
+  while(<L>) { # each lexicon line is "word phone1 phone2 ..."
+    @A = split(" ", $_);
+    $w = shift @A;
+    if(@A == 0) { # For empty words (<s> and </s>) insert no optional
+      # silence (not needed as adjacent words supply it)....
+      # actually we only hit this case for the lexicon without disambig
+      # symbols but doesn't ever matter as training transcripts don't have <s> or </s>.
+      print "$loopstate\t$loopstate\t<eps>\t$w\n";
+    } else {
+      $s = $loopstate;
+      $word_or_eps = $w; # the word is emitted on the first arc only
+      while (@A > 0) {
+        $p = shift @A;
+        if(@A > 0) {
+          $ns = $nextstate++;
+        } else {
+          $ns = $loopstate; # last phone returns to the loop state
+        }
+        print "$s\t$ns\t$p\t$word_or_eps\n";
+        $word_or_eps = "<eps>";
+        $s = $ns;
+      }
+    }
+  }
+  print "$loopstate\t0\n"; # final-cost.
+} else { # have silence probs.
+  $startstate = 0;
+  $loopstate = 1;
+  $silstate = 2; # state from where we go to loopstate after emitting silence.
+  $nextstate = 3;
+  print "$startstate\t$loopstate\t<eps>\t<eps>\t$nosilcost\n"; # no silence.
+  print "$startstate\t$loopstate\t$silphone\t<eps>\t$silcost\n"; # silence.
+  print "$silstate\t$loopstate\t$silphone\t<eps>\n"; # no cost.
+  while(<L>) { # each lexicon line is "word phone1 phone2 ..."
+    @A = split(" ", $_);
+    $w = shift @A;
+    if(@A == 0) { # For empty words (<s> and </s>) insert no optional
+      # silence (not needed as adjacent words supply it)....
+      # actually we only hit this case for the lexicon without disambig
+      # symbols but doesn't ever matter as training transcripts don't have <s> or </s>.
+      print "$loopstate\t$loopstate\t<eps>\t$w\n";
+    } else {
+      $is_silence_word = (@A == 1 && $A[0] eq $silphone); # boolean.
+      $s = $loopstate;
+      $word_or_eps = $w; # the word is emitted on the first arc only
+      while (@A > 0) {
+        $p = shift @A;
+        if(@A > 0) {
+          $ns = $nextstate++;
+          print "$s\t$ns\t$p\t$word_or_eps\n";
+          $word_or_eps = "<eps>";
+          $s = $ns;
+        } else {
+          if(! $is_silence_word) {
+            # This is non-deterministic but relatively compact,
+            # and avoids epsilons.
+            print "$s\t$loopstate\t$p\t$word_or_eps\t$nosilcost\n";
+            print "$s\t$silstate\t$p\t$word_or_eps\t$silcost\n";
+          } else {
+            # no point putting opt-sil after silence word.
+            print "$s\t$loopstate\t$p\t$word_or_eps\n";
+          }
+          $word_or_eps = "<eps>";
+        }
+      }
+    }
+  }
+  print "$loopstate\t0\n"; # final-cost.
+}
diff --git a/egs/rm/s2/scripts/make_phones_symtab.pl b/egs/rm/s2/scripts/make_phones_symtab.pl
new file mode 100755
index 000000000..0a5776e76
--- /dev/null
+++ b/egs/rm/s2/scripts/make_phones_symtab.pl
@@ -0,0 +1,37 @@
+#!/usr/bin/perl
+# Copyright 2010-2011 Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# make_phones_symtab.pl < lexicon.txt > phones.txt
+
+
+while(<>) {
+  @A = split(" ", $_);
+  for ($i=2; $i<@A; $i++) {
+    $P{$A[$i]} = 1; # seen it.
+ } +} + +print "\t0\n"; +$n = 1; +foreach $p (sort keys %P) { + if($p ne "") { + print "$p\t$n\n"; + $n++; + } +} + +print "sil\t$n\n"; + diff --git a/egs/rm/s2/scripts/make_rm_dict.pl b/egs/rm/s2/scripts/make_rm_dict.pl new file mode 100755 index 000000000..12d0a3363 --- /dev/null +++ b/egs/rm/s2/scripts/make_rm_dict.pl @@ -0,0 +1,130 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Yanmin Qian Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This file takes as input the file pcdsril.txt that comes with the RM +# distribution, and creates the dictionary used in RM training. 
+
+# make_rm_dict.pl pcdsril.txt > dct.txt
+
+if (@ARGV != 1) {
+  die "usage: make_rm_dict.pl pcdsril.txt > dct.txt\n";
+}
+unless (open(IN_FILE, "<$ARGV[0]")) {
+  die ("can't open $ARGV[0]");
+}
+
+while ($line = <IN_FILE>)
+{
+  chomp($line); # strip only the line terminator (chop would eat a real character on an unterminated last line)
+  if (($line =~ /^[a-z]/)) # only lines starting with a lower-case letter are dictionary entries
+  {
+    $line =~ s/\+1//g; # drop the "+1" markers
+    @LineArray = split(/\s+/,$line);
+    $LineArray[0] = uc($LineArray[0]); # the word itself is written in upper case
+
+    printf "%-16s", $LineArray[0];
+    for ($i = 1; $i < @LineArray; $i ++) # map each source phone onto the output phone set
+    {
+      if ($LineArray[$i] eq 'q')
+      {}
+      elsif ($LineArray[$i] eq 'zh')
+      {
+        printf "sh ";
+      }
+      elsif ($LineArray[$i] eq 'eng')
+      {
+        printf "ng ";
+      }
+      elsif ($LineArray[$i] eq 'hv')
+      {
+        printf "hh ";
+      }
+      elsif ($LineArray[$i] eq 'em')
+      {
+        printf "m ";
+      }
+      elsif ($LineArray[$i] eq 'axr')
+      {
+        printf "er ";
+      }
+      elsif ($LineArray[$i] eq 'tcl') # a closure is printed only when its release does not follow
+      {
+        if ($LineArray[$i+1] ne 't')
+        {
+          printf "td ";
+        }
+      }
+      elsif ($LineArray[$i] eq 'dcl')
+      {
+        if ($LineArray[$i+1] ne 'd')
+        {
+          printf "dd ";
+        }
+      }
+      elsif ($LineArray[$i] eq 'kcl')
+      {
+        if ($LineArray[$i+1] ne 'k')
+        {
+          printf "kd ";
+        }
+      }
+      elsif ($LineArray[$i] eq 'pcl')
+      {
+        if ($LineArray[$i+1] ne 'p')
+        {
+          printf "pd ";
+        }
+      }
+      elsif ($LineArray[$i] eq 'bcl')
+      {
+        if ($LineArray[$i+1] ne 'b')
+        {
+          printf "b ";
+        }
+      }
+      elsif ($LineArray[$i] eq 'gcl')
+      {
+        if ($LineArray[$i+1] ne 'g')
+        {
+          printf "g ";
+        }
+      }
+      elsif ($LineArray[$i] eq 't') # "t s" is merged into the single phone "ts"
+      {
+        if ($LineArray[$i+1] ne 's')
+        {
+          print "$LineArray[$i] ";
+        }
+        else
+        {
+          printf "ts ";
+          $i++; # skip the 's' that was just merged
+        }
+      }
+      else
+      {
+        print "$LineArray[$i] "; # print, not printf: the phone is data, not a format string
+      }
+    }
+    printf "\n";
+  }
+}
+
+printf "!SIL sil\n";
+
+close(IN_FILE);
+
+
diff --git a/egs/rm/s2/scripts/make_rm_lm.pl b/egs/rm/s2/scripts/make_rm_lm.pl
new file mode 100755
index 000000000..c5af12d75
--- /dev/null
+++ b/egs/rm/s2/scripts/make_rm_lm.pl
@@ -0,0 +1,119 @@
+#!/usr/bin/perl
+
+# Copyright 2010-2011 Yanmin Qian Microsoft Corporation
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# This file takes as input the file wp_gram.txt that comes with the RM +# distribution, and creates the language model as an acceptor in FST form. + +# make_rm_lm.pl wp_gram.txt > G.txt + +if (@ARGV != 1) { + print STDERR "usage: make_rm_lm.pl wp_gram.txt > G.txt\n"; # stderr, so the usage text never ends up inside G.txt + exit(1); +} +unless (open(IN_FILE, "@ARGV[0]")) { + die ("can't open @ARGV[0]"); +} + + +$flag = 0; +$count_wrd = 0; +$cnt_ends = 0; +$init = ""; + +while ($line = <IN_FILE>) # each ">WORD" line starts a new predecessor; following lines are its allowed successors +{ + chomp($line); # strip only the newline; chop() would damage an unterminated last line + + $line =~ s/ //g; + + if(($line =~ /^>/)) + { + if($flag == 0) + { + $flag = 1; + } + $line =~ s/>//g; + $hashcnt{$init} = $i; + $init = $line; + $i = 0; + $count_wrd++; + @LineArray[$count_wrd - 1] = $init; + $hashwrd{$init} = 0; + } + elsif($flag != 0) + { + + $hash{$init}[$i] = $line; + $i++; + if($line =~ /SENTENCE-END/) + { + $cnt_ends++; + } + } + else + {} +} + +$hashcnt{$init} = $i; + +$num = 0; +$weight = 0; +$init_wrd = "SENTENCE-END"; +$hashwrd{$init_wrd} = @LineArray; +for($i = 0; $i < $hashcnt{$init_wrd}; $i++) +{ + $weight = -log(1/$hashcnt{$init_wrd}); + $hashwrd{$hash{$init_wrd}[$i]} = $i + 1; + print "0 $hashwrd{$hash{$init_wrd}[$i]} $hash{$init_wrd}[$i] $hash{$init_wrd}[$i] $weight\n"; +} +$num = $i; + +for($i = 0; $i < @LineArray; $i++) +{ + if(@LineArray[$i] eq 'SENTENCE-END') + {} + else + { + if($hashwrd{@LineArray[$i]} == 0) + { + $num++; + $hashwrd{@LineArray[$i]} = $num; + } + for($j = 0; $j < $hashcnt{@LineArray[$i]}; $j++) + { + $weight = -log(1/$hashcnt{@LineArray[$i]}); + if($hashwrd{$hash{@LineArray[$i]}[$j]} == 0)
+ { + $num++; + $hashwrd{$hash{@LineArray[$i]}[$j]} = $num; + } + if($hash{@LineArray[$i]}[$j] eq 'SENTENCE-END') + { + print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} <eps> <eps> $weight\n" + } + else + { + print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} $hash{@LineArray[$i]}[$j] $hash{@LineArray[$i]}[$j] $weight\n"; + } + } + } +} + +print "$hashwrd{$init_wrd} 0\n"; +close(IN_FILE); + + diff --git a/egs/rm/s2/scripts/make_roots.pl b/egs/rm/s2/scripts/make_roots.pl new file mode 100755 index 000000000..eeed0e6e3 --- /dev/null +++ b/egs/rm/s2/scripts/make_roots.pl @@ -0,0 +1,102 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Written by Dan Povey 9/21/2010. Apache 2.0 License. + +# This version of make_roots.pl is specialized for RM. + +# This script creates the file roots.txt which is an input to train-tree.cc. It +# specifies how the trees are built. The input file phone-sets.txt is a partial +# version of roots.txt in which phones are represented by their spelled form, not +# their symbol id's. E.g. at input, phone-sets.txt might contain; +# shared not-split sil +# Any phones not specified in phone-sets.txt but present in phones.txt will +# be given a default treatment. If the --separate option is given, we create +# a separate tree root for each of them, otherwise they are all lumped in one set.
+# The arguments shared|not-shared and split|not-split are needed if any +# phones are not specified in phone-sets.txt. What they mean is as follows: +# if shared=="shared" then we share the tree-root between different HMM-positions +# (0,1,2). If split=="split" then we actually do decision tree splitting on +# that root, otherwise we forbid decision-tree splitting. (The main reason we might +# set this to false is for silence when +# we want to ensure that the HMM-positions will remain with a single PDF id. + + +$separate = 0; +if($ARGV[0] eq "--separate") { + $separate = 1; + shift @ARGV; +} + +if(@ARGV != 4) { + die "Usage: make_roots.pl [--separate] phones.txt silence-phone-list[integer,colon-separated] shared|not-shared split|not-split > roots.txt\n"; +} + + +($phonesfile, $silphones, $shared, $split) = @ARGV; +if($shared ne "shared" && $shared ne "not-shared") { + die "Third argument must be \"shared\" or \"not-shared\"\n"; +} +if($split ne "split" && $split ne "not-split") { + die "Fourth argument must be \"split\" or \"not-split\"\n"; +} + + + +open(F, "<$phonesfile") || die "Opening file $phonesfile"; + +while(<F>) { + @A = split(" ", $_); + if(@A != 2) { + die "Bad line in phones symbol file: ".$_; + } + if($A[1] != 0) { + $symbol2id{$A[0]} = $A[1]; + $id2symbol{$A[1]} = $A[0]; + } +} + +if($silphones eq ""){ # string test: numeric == would treat any non-numeric list as 0 + die "Empty silence phone list in make_roots.pl"; +} +foreach $silphoneid (split(":", $silphones)) { + defined $id2symbol{$silphoneid} || die "No such silence phone id $silphoneid"; + # Give each silence phone its own separate pdfs in each state, but + # no sharing (in this recipe; WSJ is different.. in this recipe there + #is only one silence phone anyway.)
+ $issil{$silphoneid} = 1; + print "not-shared not-split $silphoneid\n"; +} + +$idlist = ""; +$remaining_phones = ""; + +if($separate){ + foreach $a (keys %id2symbol) { + if(!defined $issil{$a}) { + print "$shared $split $a\n"; + } + } +} else { + print "$shared $split "; + foreach $a (keys %id2symbol) { + if(!defined $issil{$a}) { + print "$a "; + } + } + print "\n"; +} diff --git a/egs/rm/s2/scripts/make_words_symtab.pl b/egs/rm/s2/scripts/make_words_symtab.pl new file mode 100755 index 000000000..7e8899b0e --- /dev/null +++ b/egs/rm/s2/scripts/make_words_symtab.pl @@ -0,0 +1,39 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# make_words_symtab.pl < G.txt > words.txt + + + + +while(<>) { + @A = split(" ", $_); + if(@A >= 3) { + $W{$A[2]} = 1; + } +} + +print "<eps>\t0\n"; +$n = 1; +foreach $w (sort keys %W) { + if($w ne "<eps>") { + print "$w\t$n\n"; + $n++; + } +} + +print "!SIL\t$n\n"; + diff --git a/egs/rm/s2/scripts/mkgraph.sh b/egs/rm/s2/scripts/mkgraph.sh new file mode 100755 index 000000000..a6432bcf2 --- /dev/null +++ b/egs/rm/s2/scripts/mkgraph.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +reorder=true # Dan-style, make false for Mirko+Lukas's decoder. + +for x in 1 2 3; do + if [ "$1" == "--mono" ]; then + monophone_opts="--context-size=1 --central-position=0" + shift; + fi + + if [ "$1" == "--noreorder" ]; then + reorder=false # we set this for the Kaldi decoder. + shift; + fi +done + +if [ $# != 3 ]; then + echo "Usage: scripts/mkgraph.sh <tree> <model> <graphdir>" + exit 1; +fi + +if [ -f path.sh ]; then . path.sh; fi + +tree=$1 +model=$2 +dir=$3 + +mkdir -p $dir + +tscale=1.0 +loopscale=0.1 + +fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminizestar --use-log=true | \ + fstminimizeencoded > $dir/LG.fst + +fstisstochastic $dir/LG.fst || echo "warning: LG not stochastic." + +echo "Example string from LG.fst: " +echo +fstrandgen --select=log_prob $dir/LG.fst | fstprint --isymbols=data/phones_disambig.txt --osymbols=data/words.txt - + +grep '#' data/phones_disambig.txt | awk '{print $2}' > $dir/disambig_phones.list + +fstcomposecontext $monophone_opts \ + --read-disambig-syms=$dir/disambig_phones.list \ + --write-disambig-syms=$dir/disambig_ilabels.list \ + $dir/ilabels < $dir/LG.fst >$dir/CLG.fst + + # for debugging: + fstmakecontextsyms data/phones.txt $dir/ilabels > $dir/context_syms.txt + echo "Example string from CLG.fst: " + echo + fstrandgen --select=log_prob $dir/CLG.fst | fstprint --isymbols=$dir/context_syms.txt --osymbols=data/words.txt - + +fstisstochastic $dir/CLG.fst || echo "warning: CLG not stochastic."
+ +make-ilabel-transducer --write-disambig-syms=$dir/disambig_ilabels_remapped.list $dir/ilabels $tree $model $dir/ilabels.remapped > $dir/ilabel_map.fst + +# Reduce size of CLG by remapping symbols... +fsttablecompose $dir/ilabel_map.fst $dir/CLG.fst | fstdeterminizestar --use-log=true \ + | fstminimizeencoded > $dir/CLG2.fst + + +cat $dir/CLG2.fst | fstisstochastic || echo "warning: CLG2 is not stochastic." + +make-h-transducer --disambig-syms-out=$dir/disambig_tstate.list \ + --transition-scale=$tscale $dir/ilabels.remapped $tree $model > $dir/Ha.fst + + +fsttablecompose $dir/Ha.fst $dir/CLG2.fst | fstdeterminizestar --use-log=true \ + | fstrmsymbols $dir/disambig_tstate.list | fstrmepslocal | fstminimizeencoded > $dir/HCLGa.fst + +fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic" + +add-self-loops --self-loop-scale=$loopscale --reorder=$reorder $model < $dir/HCLGa.fst > $dir/HCLG.fst + +if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then + # No point doing this test if transition-scale not 1, as it is bound to fail. + fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic." +fi + +fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic." + + +#The next five lines are debug. +# The last two lines of this block print out some alignment info.
+fstrandgen --select=log_prob $dir/HCLG.fst | fstprint --osymbols=data/words.txt > $dir/rand.txt +cat $dir/rand.txt | awk 'BEGIN{printf("0 ");} {if(NF>=3 && $3 != 0){ printf ("%d ",$3); }} END {print ""; }' > $dir/rand_align.txt + +show-alignments data/phones.txt $model ark:$dir/rand_align.txt +cat $dir/rand.txt | awk ' {if(NF>=4 && $4 != "<eps>"){ printf ("%s ",$4); }} END {print ""; }' + diff --git a/egs/rm/s2/scripts/mkgraph_alt.sh b/egs/rm/s2/scripts/mkgraph_alt.sh new file mode 100755 index 000000000..2cfc92692 --- /dev/null +++ b/egs/rm/s2/scripts/mkgraph_alt.sh @@ -0,0 +1,115 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This version of mkgraph.sh creates the C fst explicitly. + +reorder=true # Dan-style, make false for Mirko+Lukas's decoder. + +for x in 1 2 3; do + if [ "$1" == "--mono" ]; then + monophone_opts="--context-size=1 --central-position=0" + shift; + fi + + if [ "$1" == "--noreorder" ]; then + reorder=false # we set this for the Kaldi decoder. + shift; + fi +done + +if [ $# != 3 ]; then + echo "Usage: scripts/mkgraph_alt.sh <tree> <model> <graphdir>" + exit 1; +fi + +if [ -f path.sh ]; then .
path.sh; fi + + +tree=$1 +model=$2 +dir=$3 + +mkdir -p $dir + +tscale=1.0 +loopscale=0.1 + +fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminizestar --use-log=true | \ + fstminimizeencoded > $dir/LG.fst + +fstisstochastic $dir/LG.fst || echo "warning: LG not stochastic." + +echo "Example string from LG.fst: " +echo +fstrandgen --select=log_prob $dir/LG.fst | fstprint --isymbols=data/phones_disambig.txt --osymbols=data/words.txt - + +grep '#' data/phones_disambig.txt | awk '{print $2}' > $dir/disambig_phones.list +subseq_sym=`tail -1 data/phones_disambig.txt | awk '{print $2+1;}'` +cp data/phones_disambig.txt $dir/phones_disambig_subseq.txt +echo '$' $subseq_sym >> $dir/phones_disambig_subseq.txt + +fstmakecontextfst --read-disambig-syms=$dir/disambig_phones.list \ + --write-disambig-syms=$dir/disambig_ilabels.list data/phones.txt $subseq_sym \ + $dir/ilabels | fstarcsort --sort_type=olabel > $dir/C.fst + +fstaddsubsequentialloop $subseq_sym $dir/LG.fst | \ + fsttablecompose $dir/C.fst - > $dir/CLG.fst + + + # for debugging: + fstmakecontextsyms data/phones.txt $dir/ilabels > $dir/context_syms.txt + echo "Example string from CLG.fst: " + echo + fstrandgen --select=log_prob $dir/CLG.fst | fstprint --isymbols=$dir/context_syms.txt --osymbols=data/words.txt - + +fstisstochastic $dir/CLG.fst || echo "warning: CLG not stochastic." + +make-ilabel-transducer --write-disambig-syms=$dir/disambig_ilabels_remapped.list $dir/ilabels $tree $model $dir/ilabels.remapped > $dir/ilabel_map.fst + +# Reduce size of CLG by remapping symbols... +fstcompose $dir/ilabel_map.fst $dir/CLG.fst | fstdeterminizestar --use-log=true \ + | fstminimizeencoded > $dir/CLG2.fst + + +cat $dir/CLG2.fst | fstisstochastic || echo "warning: CLG2 is not stochastic."
+ +make-h-transducer --disambig-syms-out=$dir/disambig_tstate.list \ + --transition-scale=$tscale $dir/ilabels.remapped $tree $model > $dir/Ha.fst + + +fsttablecompose $dir/Ha.fst $dir/CLG2.fst | fstdeterminizestar --use-log=true \ + | fstrmsymbols $dir/disambig_tstate.list | fstrmepslocal | fstminimizeencoded > $dir/HCLGa.fst + +fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic" + +add-self-loops --self-loop-scale=$loopscale --reorder=$reorder $model < $dir/HCLGa.fst > $dir/HCLG.fst + +if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then + # No point doing this test if transition-scale not 1, as it is bound to fail. + fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic." +fi + +fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic." + + +#The next five lines are debug. +# The last two lines of this block print out some alignment info. +fstrandgen --select=log_prob $dir/HCLG.fst | fstprint --osymbols=data/words.txt > $dir/rand.txt +cat $dir/rand.txt | awk 'BEGIN{printf("0 ");} {if(NF>=3 && $3 != 0){ printf ("%d ",$3); }} END {print ""; }' > $dir/rand_align.txt +show-alignments data/phones.txt $model ark:$dir/rand_align.txt +cat $dir/rand.txt | awk ' {if(NF>=4 && $4 != "<eps>"){ printf ("%s ",$4); }} END {print ""; }' + diff --git a/egs/rm/s2/scripts/process_warps.pl b/egs/rm/s2/scripts/process_warps.pl new file mode 100755 index 000000000..c9710a0bc --- /dev/null +++ b/egs/rm/s2/scripts/process_warps.pl @@ -0,0 +1,47 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script is part of a diagnostic step when using exponential transforms. + +$map=$ARGV[0]; open(M,"<$map")||die "opening map file $map"; +while(<M>){ @A=split(" ",$_); $map{$A[0]} = $A[1]; } +while(<STDIN>){ # NOTE(review): per-speaker warps are assumed to arrive on standard input - confirm against the caller + ($spk,$warp)=split(" ",$_); + $class = int($class/2); + defined $map{$spk} || die "No gender info for speaker $spk"; + $warps{$map{$spk}} = $warps{$map{$spk}} . "$warp "; +} +@K = sort keys %warps; +@K==2||die "wrong number of keys [empty warps file?]"; +foreach $k ( @K ) { + $s = join(" ", sort { $a <=> $b } ( split(" ", $warps{$k}) )) ; + print "$k = [ $s ];\n"; +} +# f,m may be reversed below; doesnt matter. +foreach $w ( split(" ", $warps{$K[0]}) ) { + $nf += 1; $sumf += $w; $sumf2 += $w*$w; +} +foreach $w ( split(" ", $warps{$K[1]}) ) { + $nm += 1; $summ += $w; $summ2 += $w*$w; +} +$sumf /= $nf; $sumf2 /= $nf; +$summ /= $nm; $summ2 /= $nm; +$sumf2 -= $sumf*$sumf; +$summ2 -= $summ*$summ; +$avgwithin = 0.5*($sumf2+$summ2 ); +$diff = abs($sumf - $summ) / sqrt($avgwithin); +print "% class separation is $diff\n"; diff --git a/egs/rm/s2/scripts/silphones.pl b/egs/rm/s2/scripts/silphones.pl new file mode 100755 index 000000000..8cee6df94 --- /dev/null +++ b/egs/rm/s2/scripts/silphones.pl @@ -0,0 +1,57 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# creates integer lists of silence and non-silence phones in files, +# e.g. silphones.csl="1:2:3 \n" +# and nonsilphones.csl="4:5:6:7:...:24\n"; + +if(@ARGV != 4) { + die "Usage: silphones.pl phones.txt \"sil1 sil2 sil3\" silphones.csl nonsilphones.csl"; +} + +($symtab, $sillist, $silphones, $nonsilphones) = @ARGV; +open(S,"<$symtab") || die "Opening symbol table $symtab"; + + +foreach $s (split(" ", $sillist)) { + $issil{$s} = 1; +} + +@sil = (); +@nonsil = (); +while(<S>){ + @A = split(" ", $_); + @A == 2 || die "Bad line $_ in phone-symbol-table file $symtab"; + ($sym, $int) = @A; + if($int != 0) { + if($issil{$sym}) { push @sil, $int; $seensil{$sym}=1; } + else { push @nonsil, $int; } + } +} + +foreach $k(keys %issil) { + if(!$seensil{$k}) { die "No such silence phone $k"; } +} +open(F, ">$silphones") || die "opening silphones file $silphones"; +open(G, ">$nonsilphones") || die "opening nonsilphones file $nonsilphones"; +print F join(":", @sil) . "\n"; +print G join(":", @nonsil) .
"\n"; +close(F); +close(G); +if(@sil == 0) { print STDERR "Warning: silphones.pl no silence phones.\n" } +if(@nonsil == 0) { print STDERR "Warning: silphones.pl no non-silence phones.\n" } + diff --git a/egs/rm/s2/scripts/spk2utt_to_utt2spk.pl b/egs/rm/s2/scripts/spk2utt_to_utt2spk.pl new file mode 100755 index 000000000..ca8a6a124 --- /dev/null +++ b/egs/rm/s2/scripts/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/rm/s2/scripts/split_scp.pl b/egs/rm/s2/scripts/split_scp.pl new file mode 100755 index 000000000..f30d217a6 --- /dev/null +++ b/egs/rm/s2/scripts/split_scp.pl @@ -0,0 +1,181 @@ +#!/usr/bin/perl -w +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + + +# This program splits up any kind of .scp or archive-type file. +# If there is no utt2spk option it will work on any text file and +# will split it up with an approximately equal number of lines in +# each output file. +# With the --utt2spk option it will work on anything that has the +# utterance-id as the first entry on each line; the utt2spk file is +# of the form "utterance speaker" (on each line). +# It splits it into equal size chunks as far as it can. If you use +# the utt2spk option it will make sure these chunks coincide with +# speaker boundaries. In this case, if there are more chunks +# than speakers (and in some other circumstances), some of the +# resulting chunks will be empty and it +# will print a warning. +# You will normally call this like: +# split_scp.pl scp scp.1 scp.2 scp.3 ... +# or +# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... +# Note that you can use this script to split the utt2spk file itself, +# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... + +if(@ARGV < 2 ) { + die "Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ... "; +} + +if($ARGV[0] =~ m:^-:) { + # Everything inside this block + # corresponds to what we do when the --utt2spk option is used.
+ $opt = shift @ARGV; + @A = split("=", $opt); + if(@A != 2 || $A[0] ne "--utt2spk") { + die "split_scp.pl: invalid option $opt"; + } + $utt2spk_file = $A[1]; + open(U, "<$utt2spk_file") || die "Failed to open utt2spk file $utt2spk_file"; + while(<U>) { + @A = split; + @A == 2 || die "Bad line $_ in utt2spk file $utt2spk_file"; + ($u,$s) = @A; + $utt2spk{$u} = $s; + } + $inscp = shift @ARGV; + open(I, "<$inscp") || die "Opening input scp file $inscp"; + @spkrs = (); + while(<I>) { + @A = split; + if(@A == 0) { die "Empty or space-only line in scp file $inscp"; } + $u = $A[0]; + $s = $utt2spk{$u}; + if(!defined $s) { die "No such utterance $u in utt2spk file $utt2spk_file"; } + if(!defined $spk_count{$s}) { + push @spkrs, $s; + $spk_count{$s} = 0; + $spk_data{$s} = ""; + } + $spk_count{$s}++; + $spk_data{$s} = $spk_data{$s} . $_; + } + # Now split as equally as possible .. + # First allocate spks to files by given approximately + # equal #spks. + $numspks = @spkrs; # number of speakers. + $numscps = @ARGV; # number of output files. + $spksperscp = int( ($numspks+($numscps-1)) / $numscps); # the +$(numscps-1) forces rounding up. + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scparray[$scpidx] = []; $scpcount[$scpidx] = 0; # [] is array reference; count starts at 0 so empty chunks stay defined under -w. + for($n = $spksperscp * $scpidx; + $n < $numspks && $n < $spksperscp*($scpidx+1); + $n++) { + $spk = $spkrs[$n]; + push @{$scparray[$scpidx]}, $spk; + $scpcount[$scpidx] += $spk_count{$spk}; + } + } + # Now will try to reassign beginning + ending speakers + # to different scp's and see if it gets more balanced. + # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2. + # We can show that if considering changing just 2 scp's, we minimize + # this by minimizing the squared difference in sizes. This is + # equivalent to minimizing the absolute difference in sizes. This + # shows this method is bound to converge.
+ + $changed = 1; + while($changed) { + $changed = 0; + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + # First try to reassign ending spk of this scp. + if($scpidx < $numscps-1) { + $sz = @{$scparray[$scpidx]}; + if($sz > 0) { + $spk = $scparray[$scpidx]->[$sz-1]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx]; + $nutt2 = $scpcount[$scpidx+1]; + if( abs( ($nutt2+$count) - ($nutt1-$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx+1] += $count; + $scpcount[$scpidx] -= $count; + pop @{$scparray[$scpidx]}; + unshift @{$scparray[$scpidx+1]}, $spk; + $changed = 1; + } + } + } + if($scpidx > 0 && @{$scparray[$scpidx]} > 0) { + $spk = $scparray[$scpidx]->[0]; + $count = $spk_count{$spk}; + $nutt1 = $scpcount[$scpidx-1]; + $nutt2 = $scpcount[$scpidx]; + if( abs( ($nutt2-$count) - ($nutt1+$count)) + < abs($nutt2 - $nutt1)) { # Would decrease + # size-diff by reassigning spk... + $scpcount[$scpidx-1] += $count; + $scpcount[$scpidx] -= $count; + shift @{$scparray[$scpidx]}; + push @{$scparray[$scpidx-1]}, $spk; + $changed = 1; + } + } + } + } + # Now print out the files... + for($scpidx = 0; $scpidx < $numscps; $scpidx++) { + $scpfn = $ARGV[$scpidx]; + open(F, ">$scpfn") || die "Could not open scp file $scpfn for writing."; + $count = 0; + if(@{$scparray[$scpidx]} == 0) { + print STDERR "Warning: split_scp.pl producing empty .scp file $scpfn (too many splits and too few speakers?)\n"; + } + foreach $spk ( @{$scparray[$scpidx]} ) { + print F $spk_data{$spk}; + $count += $spk_count{$spk}; + } + if($count != $scpcount[$scpidx]) { die "Count mismatch [code error]"; } + close(F); + } +} else { + # This block is the "normal" case where there is no --utt2spk + # option and we just break into equal size chunks. + + $inscp = shift @ARGV; + open(I, "<$inscp") || die "Opening input scp file $inscp"; + + $numscps = @ARGV; # size of array.
+ @F = (); + while(<I>) { + push @F, $_; + } + $numlines = @F; + if($numlines == 0) { + print STDERR "split_scp.pl: warning: empty input scp file $inscp\n"; + } + $linesperscp = int( ($numlines+($numscps-1)) / $numscps); # the +$(numscps-1) forces rounding up. +# [just doing int() rounds down]. + for($scpidx = 0; $scpidx < @ARGV; $scpidx++) { + $scpfile = $ARGV[$scpidx]; + open(O, ">$scpfile") || die "Opening output scp file $scpfile"; + for($n = $linesperscp * $scpidx; $n < $numlines && $n < $linesperscp*($scpidx+1); $n++) { + print O $F[$n]; + } + close(O) || die "Closing scp file $scpfile"; + } +} diff --git a/egs/rm/s2/scripts/subset_scp.pl b/egs/rm/s2/scripts/subset_scp.pl new file mode 100755 index 000000000..e23b38788 --- /dev/null +++ b/egs/rm/s2/scripts/subset_scp.pl @@ -0,0 +1,59 @@ +#!/usr/bin/perl -w +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This program selects a subset of N elements in the scp. +# It selects them evenly from throughout the scp, in order to +# avoid selecting too many from the same speaker. +# It prints them on the standard output.
+ +if(@ARGV < 2 ) { + die "Usage: subset_scp.pl N in.scp "; +} + +$N = shift @ARGV; +if($N == 0) { + die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\""; +} +$inscp = shift @ARGV; +open(I, "<$inscp") || die "Opening input scp file $inscp"; + +@F = (); +while(<I>) { + push @F, $_; +} +$numlines = @F; +if($N > $numlines) { + die "You requested from subset_scp.pl more elements than available: $N > $numlines"; +} + +sub select_n { # prints $num_needed lines of @F from [$start,$end), halving the range recursively so picks are spread evenly + my ($start,$end,$num_needed) = @_; + my $diff = $end - $start; + if($num_needed > $diff) { die "select_n: code error"; } + if($diff == 1 ) { + if($num_needed > 0) { + print $F[$start]; + } + } else { + my $halfdiff = int($diff/2); + my $halfneeded = int($num_needed/2); + select_n($start, $start+$halfdiff, $halfneeded); + select_n($start+$halfdiff, $end, $num_needed - $halfneeded); + } +} +select_n(0, $numlines, $N); + diff --git a/egs/rm/s2/scripts/sym2int.pl b/egs/rm/s2/scripts/sym2int.pl new file mode 100755 index 000000000..4f8b218a7 --- /dev/null +++ b/egs/rm/s2/scripts/sym2int.pl @@ -0,0 +1,59 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License.
+ + +$ignore_oov = 0; +$ignore_first_field = 0; +for($x = 0; $x < 2; $x++) { + if($ARGV[0] eq "--ignore-oov") { $ignore_oov = 1; shift @ARGV; } + if($ARGV[0] eq "--ignore-first-field") { $ignore_first_field = 1; shift @ARGV; } +} + +$symtab = shift @ARGV; +if(!defined $symtab) { + die "Usage: sym2int.pl symtab [input transcriptions] > output transcriptions\n"; +} +open(F, "<$symtab") || die "Error opening symbol table file $symtab"; +while(<F>) { + @A = split(" ", $_); + @A == 2 || die "bad line in symbol table file: $_"; + $sym2int{$A[0]} = $A[1] + 0; +} + +while(<>) { + @A = split(" ", $_); + if(@A == 0) { + die "Empty line in transcriptions input."; + } + if($ignore_first_field) { + $key = shift @A; + print $key . " "; + } + foreach $a (@A) { + $i = $sym2int{$a}; + if(!defined ($i)) { + if($ignore_oov) { + print $a . " " ; + } else { + die "sym2int.pl: undefined symbol $a\n"; + } + } + print $i . " " if defined $i; # an OOV kept by --ignore-oov was already echoed above; do not emit an extra empty field + } + print "\n"; +} + + diff --git a/egs/rm/s2/scripts/utt2spk_to_spk2utt.pl b/egs/rm/s2/scripts/utt2spk_to_spk2utt.pl new file mode 100755 index 000000000..c94eb8d53 --- /dev/null +++ b/egs/rm/s2/scripts/utt2spk_to_spk2utt.pl @@ -0,0 +1,33 @@ +#!/usr/bin/perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License.
+ + + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + $uttlist{$s} = $uttlist{$s} . "$u "; +} +foreach $s (@spklist) { + $l = $uttlist{$s}; + $l =~ s: $::; # remove trailing space. + print "$s $l\n"; +} diff --git a/egs/rm/s2/steps/decode_mono.sh b/egs/rm/s2/steps/decode_mono.sh new file mode 100755 index 000000000..92075ae0b --- /dev/null +++ b/egs/rm/s2/steps/decode_mono.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# Monophone decoding script. + +if [ -f path.sh ]; then . path.sh; fi +dir=exp/decode_mono +tree=exp/mono/tree +mkdir -p $dir +model=exp/mono/final.mdl +graphdir=exp/graph_mono + +scripts/mkgraph.sh --mono $tree $model $graphdir + +for test in mar87 oct87 feb89 oct89 feb91 sep92; do + ( + feats="ark:add-deltas --print-args=false scp:data/test_${test}.scp ark:- |" + + gmm-decode-faster --beam=20.0 --acoustic-scale=0.08333 --word-symbol-table=data/words.txt $model $graphdir/HCLG.fst "$feats" ark,t:$dir/test_${test}.tra ark,t:$dir/test_${test}.ali 2> $dir/decode_${test}.log + + # the ,p option lets it score partial output without dying..
+ scripts/sym2int.pl --ignore-first-field data/words.txt data_prep/test_${test}_trans.txt | \ + compute-wer --mode=present ark:- ark,p:$dir/test_${test}.tra >& $dir/wer_${test} + ) & +done + +wait + +grep WER $dir/wer_* | \ + awk '{n=n+$4; d=d+$6} END{ printf("Average WER is %f (%d / %d) \n", (100.0*n)/d, n, d); }' \ + > $dir/wer + diff --git a/egs/rm/s2/steps/make_mfcc_test.sh b/egs/rm/s2/steps/make_mfcc_test.sh new file mode 100755 index 000000000..df109dc8b --- /dev/null +++ b/egs/rm/s2/steps/make_mfcc_test.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# To be run from .. (one directory up from here) + +if [ $# != 1 ]; then + echo "usage: make_mfcc_test.sh <feature-output-dir>" + exit 1; +fi + +if [ -f path.sh ]; then . path.sh; fi + +dir=exp/make_mfcc +mkdir -p $dir +root_out=$1 +mkdir -p $root_out + +for test in mar87 oct87 feb89 oct89 feb91 sep92; do + scpin=data_prep/test_${test}_wav.scp +# Making it like this so it works for others on the BUT filesystem. +# It will generate the correct scp file without running the feature extraction.
+ log=$dir/make_mfcc_test_${test}.log + ( + compute-mfcc-feats --verbose=2 --config=conf/mfcc.conf scp:$scpin ark,scp:$root_out/test_${test}_raw_mfcc.ark,$root_out/test_${test}_raw_mfcc.scp 2> $log || tail $log + cp $root_out/test_${test}_raw_mfcc.scp data/test_${test}.scp + ) & +done + +wait + +echo "If the above produced no output on the screen, it succeeded." diff --git a/egs/rm/s2/steps/make_mfcc_train.sh b/egs/rm/s2/steps/make_mfcc_train.sh new file mode 100755 index 000000000..b910a23aa --- /dev/null +++ b/egs/rm/s2/steps/make_mfcc_train.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# To be run from .. (one directory up from here) + +if [ $# != 1 ]; then + echo "usage: make_mfcc_train.sh <feature-output-dir>"; + exit 1; +fi + +if [ -f path.sh ]; then . 
path.sh; fi + +scpin=data_prep/train_wav.scp +dir=exp/make_mfcc +mkdir -p $dir +root_out=$1 +mkdir -p $root_out + +scripts/split_scp.pl $scpin $dir/train_wav{1,2,3,4}.scp + +for n in 1 2 3 4; do # Use 4 CPUs + log=$dir/make_mfcc_train.$n.log + compute-mfcc-feats --verbose=2 --config=conf/mfcc.conf scp:$dir/train_wav${n}.scp ark,scp:$root_out/train_raw_mfcc${n}.ark,$root_out/train_raw_mfcc${n}.scp 2> $log || tail $log & +done + +wait; + +cat $root_out/train_raw_mfcc{1,2,3,4}.scp > data/train.scp + +echo "If the above produced no output on the screen, it succeeded."
diff --git a/egs/rm/s2/steps/prepare_graphs.sh b/egs/rm/s2/steps/prepare_graphs.sh new file mode 100755 index 000000000..a0a8cde88 --- /dev/null +++ b/egs/rm/s2/steps/prepare_graphs.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# The output of this script is the symbol tables data/{words.txt,phones.txt}, +# and the grammars and lexicons data/{L,G}{,_disambig}.fst + +# To be run from .. +if [ -f path.sh ]; then . path.sh; fi + +cp data_prep/G.txt data/ +scripts/make_words_symtab.pl < data/G.txt > data/words.txt +cp data_prep/lexicon.txt data/ + + +scripts/make_phones_symtab.pl < data/lexicon.txt > data/phones.txt + +silphones="sil"; # This would in general be a space-separated list of all silence phones. E.g. "sil vn" +# Generate colon-separated lists of silence and non-silence phones. +scripts/silphones.pl data/phones.txt "$silphones" data/silphones.csl data/nonsilphones.csl + +ndisambig=`scripts/add_lex_disambig.pl data/lexicon.txt data/lexicon_disambig.txt` +scripts/add_disambig.pl data/phones.txt $ndisambig > data/phones_disambig.txt + + +# Create train transcripts in integer format: +cat data_prep/train_trans.txt | \ + scripts/sym2int.pl --ignore-first-field data/words.txt > data/train.tra + + +# Get lexicon in FST format. + +# silprob = 0.5: same prob as word. 
+scripts/make_lexicon_fst.pl data/lexicon.txt 0.5 sil | fstcompile --isymbols=data/phones.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel > data/L.fst + +scripts/make_lexicon_fst.pl data/lexicon_disambig.txt 0.5 sil | fstcompile --isymbols=data/phones_disambig.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel > data/L_disambig.fst + +fstcompile --isymbols=data/words.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false data/G.txt > data/G.fst + +# Checking that G is stochastic [note, it wouldn't be for an Arpa] +fstisstochastic data/G.fst || echo Error + + +# Checking that disambiguated lexicon times G is determinizable +fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminize >/dev/null || echo Error + +# Checking that LG is stochastic: +fsttablecompose data/L.fst data/G.fst | fstisstochastic || echo Error + +## Check lexicon. +## just have a look and make sure it seems sane. +fstprint --isymbols=data/phones.txt --osymbols=data/words.txt data/L.fst | head + diff --git a/egs/rm/s2/steps/train_mono.sh b/egs/rm/s2/steps/train_mono.sh new file mode 100755 index 000000000..c0d45d104 --- /dev/null +++ b/egs/rm/s2/steps/train_mono.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# Copyright 2010-2011 Microsoft Corporation Arnab Ghoshal + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. 
+ + +# To be run from .. +if [ -f path.sh ]; then . path.sh; fi + +# Train the monophone on a subset-- no point using all the data. +dir=exp/mono +n=1000 +feats="ark:add-deltas --print-args=false scp:$dir/train.scp ark:- |" +# need to quote when passing as an argument, as in "$feats", +# since it has spaces in it. +scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1" + +numiters=30 # Number of iterations of training +maxiterinc=20 # Last iter to increase #Gauss on. +numgauss=250 # Initial num-Gauss (must be more than #states=3*phones). +totgauss=1000 # Target #Gaussians. +incgauss=$[($totgauss-$numgauss)/$maxiterinc] # per-iter increment for #Gauss +realign_iters="1 2 3 4 5 6 7 8 9 10 12 15 20 25"; + + +mkdir -p $dir +scripts/subset_scp.pl $n data/train.scp > $dir/train.scp + + +silphones=`cat data/silphones.csl | sed 's/:/ /g'` +nonsilphones=`cat data/nonsilphones.csl | sed 's/:/ /g'` +cat conf/topo.proto | sed "s:NONSILENCEPHONES:$nonsilphones:" | sed "s:SILENCEPHONES:$silphones:" > $dir/topo + +gmm-init-mono '--train-feats=ark:head -10 data/train.scp | add-deltas scp:- ark:- |' $dir/topo 39 $dir/0.mdl $dir/tree 2> $dir/init.out || exit 1; + + +echo "Compiling training graphs" +compile-train-graphs $dir/tree $dir/0.mdl data/L.fst \ + "ark:scripts/subset_scp.pl $n data/train.tra|" \ + "ark:|gzip -c >$dir/graphs.fsts.gz" 2>$dir/compile_graphs.log || exit 1 + +echo Pass 0 + +align-equal-compiled "ark:gunzip -c $dir/graphs.fsts.gz|" "$feats" \ + ark,t,f:- 2>$dir/align.0.log | \ + gmm-acc-stats-ali --binary=true $dir/0.mdl "$feats" ark:- \ + $dir/0.acc 2> $dir/acc.0.log || exit 1; + +# In the following steps, the --min-gaussian-occupancy=3 option is important, otherwise +# we fail to est "rare" phones and later on, they never align properly. 
+gmm-est --min-gaussian-occupancy=3 --mix-up=$numgauss \ + $dir/0.mdl $dir/0.acc $dir/1.mdl 2> $dir/update.0.log || exit 1; + +rm $dir/0.acc + + +beam=4 # will change to 8 below after 1st pass +x=1 +while [ $x -lt $numiters ]; do + echo "Pass $x" + if echo $realign_iters | grep -w $x >/dev/null; then + echo "Aligning data" + gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$[$beam*4] $dir/$x.mdl \ + "ark:gunzip -c $dir/graphs.fsts.gz|" "$feats" t,ark:$dir/cur.ali \ + 2> $dir/align.$x.log || exit 1; + fi + gmm-acc-stats-ali --binary=false $dir/$x.mdl "$feats" ark:$dir/cur.ali $dir/$x.acc 2> $dir/acc.$x.log || exit 1; + gmm-est --mix-up=$numgauss $dir/$x.mdl $dir/$x.acc $dir/$[$x+1].mdl 2> $dir/update.$x.log || exit 1; + rm $dir/$x.mdl $dir/$x.acc + if [ $x -le $maxiterinc ]; then + numgauss=$[$numgauss+$incgauss]; + fi + beam=8 + x=$[$x+1] +done + +( cd $dir; rm final.mdl 2>/dev/null; ln -s $x.mdl final.mdl ) + +# example of showing the alignments: +# show-alignments data/phones.txt $dir/30.mdl ark:$dir/cur.ali | head -4 + diff --git a/egs/rm/s2/steps/train_nnet.sh b/egs/rm/s2/steps/train_nnet.sh new file mode 100755 index 000000000..943e6fb85 --- /dev/null +++ b/egs/rm/s2/steps/train_nnet.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +# To be run from .. +if [ -f path.sh ]; then . 
path.sh; fi + +dir=exp/nnet +mkdir -p $dir/{log,nnet} + +#use following features and alignments +cp exp/mono/train.scp exp/mono/cur.ali $dir +head -n 800 $dir/train.scp > $dir/train.scp.tr +tail -n 200 $dir/train.scp > $dir/train.scp.cv +feats="ark:add-deltas --print-args=false scp:$dir/train.scp ark:- |" +feats_tr="ark:add-deltas --print-args=false scp:$dir/train.scp.tr ark:- |" +feats_cv="ark:add-deltas --print-args=false scp:$dir/train.scp.cv ark:- |" +labels="ark:$dir/cur.ali" + +#compute per utterance CMVN +cmvn="ark:$dir/cmvn.ark" +compute-cmvn-stats "$feats" $cmvn +feats_tr="$feats_tr apply-cmvn --print-args=false --norm-vars=true $cmvn ark:- ark:- |" +feats_cv="$feats_cv apply-cmvn --print-args=false --norm-vars=true $cmvn ark:- ark:- |" + + +#initialize the nnet +mlp_init=$dir/nnet.init +scripts/gen_mlp_init.py --dim=39:512:300 --gauss --negbias > $mlp_init + +#global config for trainig +max_iters=20 +start_halving_inc=0.5 +end_halving_inc=0.1 +lrate=0.001 + + + +nnet-train-xent-hardlab-perutt --cross-validate=true $mlp_init "$feats_cv" "$labels" &> $dir/log/prerun.log +if [ $? != 0 ]; then cat $dir/log/prerun.log; exit 1; fi +acc=$(cat $dir/log/prerun.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1) +echo CROSSVAL PRERUN ACCURACY $acc + +mlp_best=$mlp_init +mlp_base=${mlp_init##*/}; mlp_base=${mlp_base%.*} +halving=0 +for iter in $(seq -w $max_iters); do + mlp_next=$dir/nnet/${mlp_base}_iter${iter} + nnet-train-xent-hardlab-perutt --learn-rate=$lrate $mlp_best "$feats_tr" "$labels" $mlp_next &> $dir/log/iter$iter.log + if [ $? != 0 ]; then cat $dir/log/iter$iter.log; exit 1; fi + tr_acc=$(cat $dir/log/iter$iter.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1) + echo TRAIN ITERATION $iter ACCURACY $tr_acc LRATE $lrate + nnet-train-xent-hardlab-perutt --cross-validate=true $mlp_next "$feats_cv" "$labels" 1>>$dir/log/iter$iter.log 2>>$dir/log/iter$iter.log + if [ $? 
!= 0 ]; then cat $dir/log/iter$iter.log; exit 1; fi + + #accept or reject new parameters + acc_new=$(cat $dir/log/iter$iter.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1) + echo CROSSVAL ITERATION $iter ACCURACY $acc_new + acc_prev=$acc + if [ 1 == $(awk 'BEGIN{print('$acc_new' > '$acc')}') ]; then + acc=$acc_new + mlp_best=$dir/nnet/$mlp_base.iter${iter}_tr$(printf "%.5g" $tr_acc)_cv$(printf "%.5g" $acc_new) + mv $mlp_next $mlp_best + echo nnet $mlp_best accepted + else + mlp_reject=$dir/nnet/$mlp_base.iter${iter}_tr$(printf "%.5g" $tr_acc)_cv$(printf "%.5g" $acc_new) + mv $mlp_next $mlp_reject + echo nnet $mlp_reject rejected + fi + + #stopping criterion + if [[ 1 == $halving && 1 == $(awk 'BEGIN{print('$acc' < '$acc_prev'+'$end_halving_inc')}') ]]; then + echo finished, too small improvement $(awk 'BEGIN{print('$acc'-'$acc_prev')}') + break + fi + + #start annealing when improvement is low + if [ 1 == $(awk 'BEGIN{print('$acc' < '$acc_prev'+'$start_halving_inc')}') ]; then + halving=1 + fi + + #do annealing + if [ 1 == $halving ]; then + lrate=$(awk 'BEGIN{print('$lrate'*0.5)}') + fi +done + +if [ $mlp_best != $mlp_init ]; then + iter=$(echo $mlp_best | sed 's/^.*iter\([0-9][0-9]*\).*$/\1/') +fi +mlp_final=$dir/${mlp_base}_final_iter${iter:-0}_acc${acc} +cp $mlp_best $mlp_final +echo final network $mlp_final + diff --git a/src/Makefile b/src/Makefile index 7f88cb790..37d357950 100644 --- a/src/Makefile +++ b/src/Makefile @@ -7,7 +7,8 @@ SUBDIRS = base matrix util feat tree optimization gmm transform sgmm \ fstext hmm lm decoder \ - bin fstbin gmmbin fgmmbin sgmmbin featbin + bin fstbin gmmbin fgmmbin sgmmbin featbin \ + nnet nnetbin all: $(SUBDIRS) echo Done diff --git a/src/nnet/Makefile b/src/nnet/Makefile new file mode 100644 index 000000000..b8d63847b --- /dev/null +++ b/src/nnet/Makefile @@ -0,0 +1,47 @@ + + +all: + +include ../kaldi.mk + +TESTFILES = nnet-test + +OBJFILES = nnet-nnet.o nnet-component.o nnet-loss.o + +LIBFILE = kaldi-nnet.a 
+ +all: $(LIBFILE) $(TESTFILES) + +$(LIBFILE): $(OBJFILES) + $(AR) -cru $(LIBFILE) $(OBJFILES) + $(RANLIB) $(LIBFILE) + + + +$(TESTFILES): $(LIBFILE) ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../util/kaldi-util.a + +# Rule below would expand to, e.g.: +# ../base/kaldi-base.a: +# make -c ../base kaldi-base.a +# -c option to make is same as changing directory. +%.a: + $(MAKE) -C ${@D} ${@F} + +clean: + rm *.o *.a $(TESTFILES) + +test: $(TESTFILES) + for x in $(TESTFILES); do ./$$x >&/dev/null || (echo "***test $$x failed***"; exit 1); done + echo Tests succeeded + +.valgrind: $(TESTFILES) + + +depend: + -$(CXX) -M $(CXXFLAGS) *.cc > .depend.mk + +# removing automatic making of "depend" as it's quite slow. +#.depend.mk: depend + +-include .depend.mk + diff --git a/src/nnet/nnet-activation.h b/src/nnet/nnet-activation.h new file mode 100644 index 000000000..05f45db65 --- /dev/null +++ b/src/nnet/nnet-activation.h @@ -0,0 +1,94 @@ +// nnet/nnet-activation.h + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#ifndef KALDI_NNET_ACTIVATION_H +#define KALDI_NNET_ACTIVATION_H + +#include "nnet/nnet-component.h" +namespace kaldi { + +class Sigmoid : public Component { + public: + Sigmoid(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet) + : Component(dim_in, dim_out, nnet) + { } + ~Sigmoid() + { } + + ComponentType GetType() const { + return kSigmoid; + } + + void PropagateFnc(const Matrix& in, Matrix* out) { + //y = 1/(1+e^-x) + for(MatrixIndexT r=0; rNumRows(); r++) { + for(MatrixIndexT c=0; cNumCols(); c++) { + (*out)(r,c) = 1.0/(1.0+exp(-in(r,c))); + } + } + } + + void BackpropagateFnc(const Matrix& in_err, Matrix* out_err) { + //ey = y(1-y)ex + const Matrix& y = nnet_->PropagateBuffer()[nnet_->IndexOfLayer(*this)+1]; + + for(MatrixIndexT r=0; rNumRows(); r++) { + for(MatrixIndexT c=0; cNumCols(); c++) { + (*out_err)(r,c) = y(r,c)*(1.0-y(r,c))*in_err(r,c); + } + } + } +}; + + +class Softmax : public Component { + public: + Softmax(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet) + : Component(dim_in, dim_out, nnet) + { } + ~Softmax() + { } + + ComponentType GetType() const { + return kSoftmax; + } + + void PropagateFnc(const Matrix& in, Matrix* out) { + //y = e^x_j/sum_j(e^x_j) + out->CopyFromMat(in); + for(MatrixIndexT r=0; rNumRows(); r++) { + out->Row(r).ApplySoftMax(); + } + } + + void BackpropagateFnc(const Matrix& in_err, Matrix* out_err) { + //simply copy the error + //(ie. 
assume crossentropy error function, + // while in_err contains (net_output-target) : + // this is already derivative of the error with + // respect to activations of last layer neurons) + out_err->CopyFromMat(in_err); + } +}; + + + +} // namespace + +#endif + diff --git a/src/nnet/nnet-biasedlinearity.h b/src/nnet/nnet-biasedlinearity.h new file mode 100644 index 000000000..6a9b59185 --- /dev/null +++ b/src/nnet/nnet-biasedlinearity.h @@ -0,0 +1,127 @@ +// nnet/nnet-biasedlinearity.h + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#ifndef KALDI_NNET_BIASEDLINEARITY_H +#define KALDI_NNET_BIASEDLINEARITY_H + + +#include "nnet/nnet-component.h" + +namespace kaldi { + +class BiasedLinearity : public UpdatableComponent { + public: + BiasedLinearity(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet) + : UpdatableComponent(dim_in, dim_out, nnet), + linearity_(dim_out,dim_in), bias_(dim_out), + linearity_corr_(dim_out,dim_in), bias_corr_(dim_out) + { } + ~BiasedLinearity() + { } + + ComponentType GetType() const { + return kBiasedLinearity; + } + + void ReadData(std::istream& is, bool binary) { + linearity_.Read(is,binary); + bias_.Read(is,binary); + + KALDI_ASSERT(linearity_.NumRows() == output_dim_); + KALDI_ASSERT(linearity_.NumCols() == input_dim_); + KALDI_ASSERT(bias_.Dim() == output_dim_); + } + + void WriteData(std::ostream& os, bool binary) const { + linearity_.Write(os,binary); + bias_.Write(os,binary); + } + + void PropagateFnc(const Matrix& in, Matrix* out) { + //precopy bias + for (MatrixIndexT i=0; iNumRows(); i++) { + out->CopyRowFromVec(bias_,i); + } + //multiply by weights^t + out->AddMatMat(1.0,in,kNoTrans,linearity_,kTrans,1.0); + } + + void BackpropagateFnc(const Matrix& in_err, Matrix* out_err) { + //multiply error by weights + out_err->AddMatMat(1.0,in_err,kNoTrans,linearity_,kNoTrans,0.0); + } + + + void Update(const Matrix& input, const Matrix& err) { + + //compute gradient + linearity_corr_.AddMatMat(1.0,err,kTrans,input,kNoTrans,momentum_); + bias_corr_.Scale(momentum_); + bias_corr_.AddRowSumMat(err); + //l2 regularization + if(l2_penalty_ != 0.0) { + linearity_.AddMat(-learn_rate_*l2_penalty_*input.NumRows(),linearity_); + } + //l1 regularization + if(l1_penalty_ != 0.0) { + BaseFloat l1 = learn_rate_*input.NumRows()*l1_penalty_; + for(MatrixIndexT r=0; r 0.0) ^ (before > 0.0)) { + linearity_(r,c) = 0.0; + linearity_corr_(r,c) = 0.0; + } else { + linearity_(r,c) -= l1sign; + } + } + } + } + //update + linearity_.AddMat(-learn_rate_,linearity_corr_); + 
bias_.AddVec(-learn_rate_,bias_corr_); + + /* + std::cout <<"I"<< input.Row(0); + std::cout <<"E"<< err.Row(0); + std::cout <<"CORL"<< linearity_corr_.Row(0); + std::cout <<"CORB"<< bias_corr_; + std::cout <<"L"<< linearity_.Row(0); + std::cout <<"B"<< bias_; + std::cout << "\n"; + */ + + //std::cout << l1_penalty_ << l2_penalty_ << momentum_ << learn_rate_ << "\n"; + } + + private: + Matrix linearity_; + Vector bias_; + + Matrix linearity_corr_; + Vector bias_corr_; +}; + +} //namespace + +#endif diff --git a/src/nnet/nnet-component.cc b/src/nnet/nnet-component.cc new file mode 100644 index 000000000..0ef2dae48 --- /dev/null +++ b/src/nnet/nnet-component.cc @@ -0,0 +1,98 @@ +// nnet/nnet-component.h + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "nnet/nnet-component.h" + +#include "nnet/nnet-nnet.h" +#include "nnet/nnet-activation.h" +#include "nnet/nnet-biasedlinearity.h" + +namespace kaldi { + + +const struct Component::key_value Component::kMarkerMap[] = { + { Component::kBiasedLinearity,"" }, + { Component::kSigmoid,"" }, + { Component::kSoftmax,"" } +}; + + +const char* Component::TypeToMarker(ComponentType t) { + int32 N=sizeof(kMarkerMap)/sizeof(kMarkerMap[0]); + for(int i=0; iReadData(is,binary); + return p_comp; +} + + +void Component::Write(std::ostream& os, bool binary) const { + WriteMarker(os,binary,Component::TypeToMarker(GetType())); + WriteBasicType(os,binary,OutputDim()); + WriteBasicType(os,binary,InputDim()); + if(!binary) os << "\n"; + this->WriteData(os,binary); +} + + +} // namespace diff --git a/src/nnet/nnet-component.h b/src/nnet/nnet-component.h new file mode 100644 index 000000000..d6e661cb3 --- /dev/null +++ b/src/nnet/nnet-component.h @@ -0,0 +1,273 @@ +// nnet/nnet-component.h + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + + +#ifndef KALDI_NNET_COMPONENT_H +#define KALDI_NNET_COMPONENT_H + + +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" +//#include "nnet/nnet-nnet.h" + +#include + +namespace kaldi { + +//declare the nnet class so we can declare pointer +class Nnet; + + +/** + * Abstract class, basic element of the network, + * it is a box with defined inputs, outputs, + * and tranformation functions interface. + * + * It is able to propagate and backpropagate + * exact implementation is to be implemented in descendants. + * + * The data buffers are not included + * and will be managed from outside. + */ +class Component +{ + ////////////////////////////////////////////////////////////// + // Disable copy construction and assignment + private: + Component(Component&); + Component& operator=(Component&); + + ////////////////////////////////////////////////////////////// + // Polymorphic Component RTTI + public: + /// Types of the net components + typedef enum { + kUnknown = 0x0, + + kUpdatableComponent = 0x0100, + kBiasedLinearity, + kSharedLinearity, + + kActivationFunction = 0x0200, + kSoftmax, + kSigmoid, + + kTranform = 0x0400, + kExpand, + kCopy, + kTranspose, + kBlockLinearity, + kBias, + kWindow, + kLog + } ComponentType; + /// Pair of type and marker + struct key_value { + const Component::ComponentType key; + const char* value; + }; + /// Mapping of types and markers + static const struct key_value kMarkerMap[]; + /// Convert component type to marker + static const char* TypeToMarker(ComponentType t); + /// Convert marker to component type + static ComponentType MarkerToType(const std::string& s); + + + ////////////////////////////////////////////////////////////// + // Constructor & Destructor + public: + Component(MatrixIndexT input_dim, MatrixIndexT output_dim, Nnet* nnet) + : input_dim_(input_dim), output_dim_(output_dim), nnet_(nnet) + { } + virtual ~Component() + { } + + ////////////////////////////////////////////////////////////// + // Public 
interface + public: + /// Get Type Identification of the component + virtual ComponentType GetType() const = 0; + /// Check if contains trainable parameters + virtual bool IsUpdatable() const { + return false; + } + + /// Get size of input vectors + MatrixIndexT InputDim() const { + return input_dim_; + } + /// Get size of output vectors + MatrixIndexT OutputDim() const { + return output_dim_; + } + + /// Perform forward pass propagateion Input->Output + void Propagate(const Matrix& in, Matrix* out); + /// Perform backward pass propagateion ErrorInput->ErrorOutput + void Backpropagate(const Matrix& in_err, Matrix* out_err); + + /// Read component from stream + static Component* Read(std::istream& is, bool binary, Nnet* nnet); + /// Write component to stream + void Write(std::ostream& os, bool binary) const; + + + /////////////////////////////////////////////////////////////// + // abstract interface for propagation/backpropagation + protected: + /// Forward pass transformation (to be implemented by descendents...) + virtual void PropagateFnc(const Matrix& in, Matrix* out) = 0; + /// Backward pass transformation (to be implemented by descendents...) 
+ virtual void BackpropagateFnc(const Matrix& in_err, Matrix* out_err) = 0; + + /// Reads the component content + virtual void ReadData(std::istream& is, bool binary) + { } + /// Writes the component content + virtual void WriteData(std::ostream& os, bool binary) const + { } + + + /////////////////////////////////////////////////////////////// + // data members + protected: + MatrixIndexT input_dim_; ///< Size of input vectors + MatrixIndexT output_dim_; ///< Size of output vectors + + Nnet* nnet_; ///< Pointer to the whole network + +}; + + +/** + * Class UpdatableComponent is a Component which has + * trainable parameters and contains some global + * parameters for stochastic gradient descent + * (learnrate,momenutm,L2,L1) + */ +class UpdatableComponent : public Component { + ////////////////////////////////////////////////////////////// + // Constructor & Destructor + public: + UpdatableComponent(MatrixIndexT input_dim, MatrixIndexT output_dim, Nnet* nnet) + : Component(input_dim,output_dim,nnet), + learn_rate_(0.0), momentum_(0.0), l2_penalty_(0.0), l1_penalty_(0.0) + { } + virtual ~UpdatableComponent() + { } + + + ////////////////////////////////////////////////////////////// + // Public interface + public: + /// Check if contains trainable parameters + bool IsUpdatable() const { + return true; + } + + /// Compute gradient and update parameters + virtual void Update(const Matrix& input, const Matrix& err) = 0; + + /// Sets the learning rate of gradient descent + void LearnRate(BaseFloat lrate) { + learn_rate_ = lrate; + } + /// Gets the learning rate of gradient descent + BaseFloat LearnRate() { + return learn_rate_; + } + + /// Sets momentum + void Momentum(BaseFloat mmt) { + momentum_ = mmt; + } + /// Gets momentum + BaseFloat Momentum() { + return momentum_; + } + + /// Sets L2 penalty (weight decay) + void L2Penalty(BaseFloat l2) { + l2_penalty_ = l2; + } + /// Gets L2 penalty (weight decay) + BaseFloat L2Penalty() { + return l2_penalty_; + } + + /// Sets 
L1 penalty (sparisity promotion) + void L1Penalty(BaseFloat l1) { + l1_penalty_ = l1; + } + /// Gets L1 penalty (sparisity promotion) + BaseFloat L1Penalty() { + return l1_penalty_; + } + + protected: + BaseFloat learn_rate_; ///< learning rate (0.0..0.01) + BaseFloat momentum_; ///< momentum value (0.0..1.0) + BaseFloat l2_penalty_; ///< L2 regularization constant (0.0..1e-4) + BaseFloat l1_penalty_; ///< L1 regularization constant (0.0..1e-4) +}; + + + + +////////////////////////////////////////////////////////////////////////// +// INLINE FUNCTIONS +// Component:: +inline void Component::Propagate(const Matrix& in, Matrix* out) { + if(input_dim_ != in.NumCols()) { + KALDI_ERR << "Nonmatching dims, component:" << input_dim_ << " data:" << in.NumCols(); + } + + if(output_dim_ != out->NumCols() || in.NumRows() != out->NumRows()) { + out->Resize(in.NumRows(), output_dim_); + } + + PropagateFnc(in, out); +} + + +inline void Component::Backpropagate(const Matrix& in_err, Matrix* out_err) { + if(output_dim_ != in_err.NumCols()) { + KALDI_ERR << "Nonmatching dims, component:" << output_dim_ + << " data:" << in_err.NumCols(); + } + + if(input_dim_ != out_err->NumCols() || in_err.NumRows() != out_err->NumRows()) { + out_err->Resize(in_err.NumRows(), input_dim_); + } + + BackpropagateFnc(in_err, out_err); +} + +////////////////////////////////////////////////////////////////////////// +// INLINE FUNCTIONS +// UpdatableComponent:: + +// nothing for now! + + + +} // namespace kaldi + + +#endif diff --git a/src/nnet/nnet-loss.cc b/src/nnet/nnet-loss.cc new file mode 100644 index 000000000..02300bbbe --- /dev/null +++ b/src/nnet/nnet-loss.cc @@ -0,0 +1,151 @@ +// nnet/nnet-loss.cc + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet/nnet-loss.h" + +#include + +namespace kaldi { + + +void Xent::Eval(const Matrix& net_out, const Matrix& target, Matrix* diff) { + KALDI_ASSERT(net_out.NumCols() == target.NumCols()); + KALDI_ASSERT(net_out.NumRows() == target.NumRows()); + diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined); + + //compute derivative wrt. activations of last layer of neurons + diff->CopyFromMat(net_out); + diff->AddMat(-1.0,target); + + //we'll not produce per-frame classification accuracy for soft labels + correct_ = -1; + + //compute xentropy + BaseFloat val; + for(int32 r=0; r& net_out, const std::vector& target, Matrix* diff) { + KALDI_ASSERT(net_out.NumRows() == (int32)target.size()); + + //check the labels + int32 max=0; + std::vector::const_iterator it; + for(it=target.begin(); it!=target.end(); ++it) { + if(max < *it) max = *it; + } + if(max > net_out.NumCols()) { + KALDI_ERR << "Network has " << net_out.NumCols() + << " outputs while having " << max << " labels"; + } + + //compute derivative wrt. 
activations of last layer of neurons + diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined); + diff->CopyFromMat(net_out); + for(int32 r=0; r<(int32)target.size(); r++) { + KALDI_ASSERT(target.at(r) <= diff->NumCols()); + (*diff)(r,target.at(r)-1) -= 1.0; + } + + //we'll not produce per-frame classification accuracy for soft labels + correct_ += Correct(net_out,target); + + //compute xentropy + BaseFloat val; + for(int32 r=0; r= 0.0) { + oss << " correct[" << 100.0*correct_/frames_ << "%]"; + } + oss << std::endl; + return oss.str(); +} + + +int32 Xent::Correct(const Matrix& net_out, const std::vector& target) { + int32 correct = 0; + for(int32 r=0; r& net_out, const Matrix& target, Matrix* diff) { + KALDI_ASSERT(net_out.NumCols() == target.NumCols()); + KALDI_ASSERT(net_out.NumRows() == target.NumRows()); + diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined); + + //compute derivative w.r.t. neural nerwork outputs + diff->CopyFromMat(net_out); + diff->AddMat(-1.0,target); + + //compute mean square error + BaseFloat val; + for(int32 r=0; r& net_out, const Matrix& target, Matrix* diff); + /// Evaluate cross entropy from soft labels + void Eval(const Matrix& net_out, const std::vector& target, Matrix* diff); + + /// Generate string with error report + std::string Report(); + + private: + int32 Correct(const Matrix& net_out, const std::vector& target); + + private: + int32 frames_; + int32 correct_; + double loss_; +}; + +class Mse { + public: + Mse() + : frames_(0), loss_(0.0) + { } + ~Mse() + { } + + /// Evaluate mean square error from target values + void Eval(const Matrix& net_out, const Matrix& target, Matrix* diff); + + /// Generate string with error report + std::string Report(); + + private: + int32 frames_; + double loss_; +}; + + + +} // namespace + +#endif + diff --git a/src/nnet/nnet-nnet.cc b/src/nnet/nnet-nnet.cc new file mode 100644 index 000000000..efb17308c --- /dev/null +++ b/src/nnet/nnet-nnet.cc @@ -0,0 +1,217 @@ +// 
nnet/nnet-nnet.cc + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet/nnet-nnet.h" +#include "nnet/nnet-component.h" +#include "nnet/nnet-activation.h" +#include "nnet/nnet-biasedlinearity.h" + +namespace kaldi { + +void Nnet::Propagate(const Matrix& in, Matrix* out) { + KALDI_ASSERT(NULL != out); + + if(LayerCount() == 0) { + out->Resize(in.NumRows(),in.NumCols(),kUndefined); + out->CopyFromMat(in); + return; + } + + //we need at least L+1 input buffers + KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1); + + + propagate_buf_[0].Resize(in.NumRows(),in.NumCols(),kUndefined); + propagate_buf_[0].CopyFromMat(in); + + for(int32 i=0; i<(int32)nnet_.size(); i++) { + nnet_[i]->Propagate(propagate_buf_[i],&propagate_buf_[i+1]); + } + + Matrix& mat = propagate_buf_[nnet_.size()]; + out->Resize(mat.NumRows(),mat.NumCols(),kUndefined); + out->CopyFromMat(mat); +} + + +void Nnet::Backpropagate(const Matrix& in_err, Matrix* out_err) { + if(LayerCount() == 0) { KALDI_ERR << "Cannot backpropagate on empty network"; } + + //we need at least L+1 input bufers + KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1); + //we need at least L-1 error bufers + KALDI_ASSERT((int32)backpropagate_buf_.size() >= LayerCount()-1); + + //find out when we can stop backprop + int32 backprop_stop = -1; + if(NULL == out_err) { + backprop_stop++; + while(1) { + 
if(nnet_[backprop_stop]->IsUpdatable()) { + if(0.0 != dynamic_cast(nnet_[backprop_stop])->LearnRate()) { + break; + } + } + backprop_stop++; + if(backprop_stop == (int32)nnet_.size()) { + KALDI_ERR << "All layers have zero learning rate!"; + break; + } + } + } + //disable! + backprop_stop=-1; + + ////////////////////////////////////// + // Backpropagation + // + + //don't copy the in_err to buffers, use it as is... + int32 i = nnet_.size()-1; + if(nnet_[i]->IsUpdatable()) { + UpdatableComponent* uc = dynamic_cast(nnet_[i]); + if(uc->LearnRate() > 0.0) { + uc->Update(propagate_buf_[i],in_err); + } + } + nnet_.back()->Backpropagate(in_err,&backpropagate_buf_[i-1]); + + //backpropagate by using buffers + for(i--; i >= 1; i--) { + if(nnet_[i]->IsUpdatable()) { + UpdatableComponent* uc = dynamic_cast(nnet_[i]); + if(uc->LearnRate() > 0.0) { + uc->Update(propagate_buf_[i],backpropagate_buf_[i]); + } + } + if(backprop_stop == i) break; + nnet_[i]->Backpropagate(backpropagate_buf_[i],&backpropagate_buf_[i-1]); + } + + //update first layer + if(nnet_[0]->IsUpdatable() && 0 >= backprop_stop) { + UpdatableComponent* uc = dynamic_cast(nnet_[0]); + if(uc->LearnRate() > 0.0) { + uc->Update(propagate_buf_[0],backpropagate_buf_[0]); + } + } + //now backpropagate through first layer, but only if asked to (by out_err pointer) + if(NULL != out_err) { + nnet_[0]->Backpropagate(backpropagate_buf_[0],out_err); + } + + // + // End of Backpropagation + ////////////////////////////////////// +} + + +void Nnet::Feedforward(const Matrix& in, Matrix* out) { + KALDI_ASSERT(NULL != out); + + if(LayerCount() == 0) { + out->Resize(in.NumRows(),in.NumCols(),kUndefined); + out->CopyFromMat(in); + return; + } + + //we need at least 2 input buffers + KALDI_ASSERT(propagate_buf_.size() >= 2); + + //propagate by using exactly 2 auxiliary buffers + int32 L = 0; + nnet_[L]->Propagate(in,&propagate_buf_[L%2]); + for(L++; L<=LayerCount()-2; L++) { + 
nnet_[L]->Propagate(propagate_buf_[(L-1)%2],&propagate_buf_[L%2]); + } + nnet_[L]->Propagate(propagate_buf_[(L-1)%2],out); +} + + +void Nnet::Read(std::istream& in, bool binary) { + //get the network layers from a factory + Component *comp; + while(NULL != (comp = Component::Read(in,binary,this))) { + if(LayerCount() > 0 && nnet_.back()->OutputDim() != comp->InputDim()) { + KALDI_ERR << "Dimensionality mismatch!" + << " Previous layer output:" << nnet_.back()->OutputDim() + << " Current layer input:" << comp->InputDim(); + } + nnet_.push_back(comp); + } + //create empty buffers + propagate_buf_.resize(LayerCount()+1); + backpropagate_buf_.resize(LayerCount()-1); + //reset learn rate + learn_rate_ = 0.0; +} + + +void Nnet::LearnRate(BaseFloat lrate, const char* lrate_factors) { + //split lrate_factors to a vector + std::vector lrate_factor_vec; + if(NULL != lrate_factors) { + char* copy = new char[strlen(lrate_factors)+1]; + strcpy(copy, lrate_factors); + char* tok = NULL; + while(NULL != (tok = strtok((tok==NULL?copy:NULL),",:; "))) { + lrate_factor_vec.push_back(atof(tok)); + } + delete copy; + } + + //count trainable layers + int32 updatable = 0; + for(int i=0; iIsUpdatable()) updatable++; + } + //check number of factors + if(lrate_factor_vec.size() > 0 && updatable != (int32)lrate_factor_vec.size()) { + KALDI_ERR << "Mismatch between number of trainable layers " << updatable + << " and learn rate factors " << lrate_factor_vec.size(); + } + + //set learn rates + updatable=0; + for(int32 i=0; iIsUpdatable()) { + BaseFloat lrate_scaled = lrate; + if(lrate_factor_vec.size() > 0) lrate_scaled *= lrate_factor_vec[updatable++]; + dynamic_cast(nnet_[i])->LearnRate(lrate_scaled); + } + } + //set global learn rate + learn_rate_ = lrate; +} + + +std::string Nnet::LearnRateString() { + std::ostringstream oss; + oss << "LEARN_RATE global: " << learn_rate_ << " individual: "; + for(int32 i=0; iIsUpdatable()) { + oss << dynamic_cast(nnet_[i])->LearnRate() << " "; + } + } + 
return oss.str(); +} + + + + + +} // namespace diff --git a/src/nnet/nnet-nnet.h b/src/nnet/nnet-nnet.h new file mode 100644 index 000000000..cdc7bf204 --- /dev/null +++ b/src/nnet/nnet-nnet.h @@ -0,0 +1,223 @@ +// nnet/nnet-nnet.h + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + + +#ifndef KALDI_NNET_NNET_H +#define KALDI_NNET_NNET_H + + +#include "base/kaldi-common.h" +#include "util/kaldi-io.h" +#include "matrix/matrix-lib.h" +#include "nnet/nnet-component.h" + +#include +#include +#include + + +namespace kaldi { + +class Nnet { + ////////////////////////////////////// + // Typedefs + typedef std::vector NnetType; + + ////////////////////////////////////////////////////////////// + // Disable copy construction and assignment + private: + Nnet(Nnet&); + Nnet& operator=(Nnet&); + + ////////////////////////////////////////////////////////////// + // Constructor & Destructor + public: + Nnet() + { } + + ~Nnet(); //{ } later... 
+ + ////////////////////////////////////////////////////////////// + // Public interface + public: + /// Perform forward pass through the network + void Propagate(const Matrix& in, Matrix* out); + /// Perform backward pass through the network + void Backpropagate(const Matrix& in_err, Matrix* out_err); + /// Perform forward pass through the network, don't keep buffers (use it when not training) + void Feedforward(const Matrix& in, Matrix* out); + + MatrixIndexT InputDim() const; ///< Dimensionality of the input features + MatrixIndexT OutputDim() const; ///< Dimensionality of the desired vectors + + MatrixIndexT LayerCount() const { ///< Get number of layers + return nnet_.size(); + } + Component* Layer(MatrixIndexT index) { ///< Access to individual layer + return nnet_[index]; + } + int IndexOfLayer(const Component& comp) const; ///< Get the position of layer in network + + /// Access to forward pass buffers + const std::vector >& PropagateBuffer() const { + return propagate_buf_; + } + + /// Access to backward pass buffers + const std::vector >& BackpropagateBuffer() const { + return backpropagate_buf_; + } + + /// Read the MLP from file (can add layers to exisiting instance of Nnet) + void Read(const std::string& file); + /// Read the MLP from stream (can add layers to exisiting instance of Nnet) + void Read(std::istream& in, bool binary); + /// Write MLP to file + void Write(const std::string& file, bool binary); + /// Write MLP to stream + void Write(std::ostream& out, bool binary); + + /// Set the learning rate values to trainable layers, + /// factors can disable training of individual layers + void LearnRate(BaseFloat lrate, const char* lrate_factors); + /// Get the global learning rate value + BaseFloat LearnRate() { + return learn_rate_; + } + /// Get the string with real learning rate values + std::string LearnRateString(); + + void Momentum(BaseFloat mmt); + void L2Penalty(BaseFloat l2); + void L1Penalty(BaseFloat l1); + + 
////////////////////////////////////////////////////////////// + // Private interface + private: + /// Creates a component by reading from stream, return NULL if no more components + static Component* ComponentFactory(std::istream& in, bool binary, Nnet* nnet); + /// Dumps individual component to stream + static void ComponentDumper(std::ostream& out, bool binary, const Component& comp); + + private: + NnetType nnet_; ///< vector of all Component*, represents layers + + std::vector > propagate_buf_; ///< buffers for forward pass + std::vector > backpropagate_buf_; ///< buffers for backward pass + + BaseFloat learn_rate_; ///< global learning rate +}; + + +////////////////////////////////////////////////////////////////////////// +// INLINE FUNCTIONS +// Nnet:: +inline Nnet::~Nnet() { + //delete all the components + NnetType::iterator it; + for(it=nnet_.begin(); it!=nnet_.end(); ++it) { + delete *it; + } +} + + +inline MatrixIndexT Nnet::InputDim() const { + if(LayerCount() > 0) { + return nnet_.front()->InputDim(); + } else { + KALDI_ERR << "No layers in MLP"; + } +} + + +inline MatrixIndexT Nnet::OutputDim() const { + if(LayerCount() > 0) { + return nnet_.back()->OutputDim(); + } else { + KALDI_ERR << "No layers in MLP"; + } +} + + +inline int32 Nnet::IndexOfLayer(const Component& comp) const { + for(int32 i=0; iWrite(out,binary); + } +} + + +inline void Nnet::Momentum(BaseFloat mmt) { + for(int32 i=0; iIsUpdatable()) { + dynamic_cast(nnet_[i])->Momentum(mmt); + } + } +} + + +inline void Nnet::L2Penalty(BaseFloat l2) { + for(int32 i=0; iIsUpdatable()) { + dynamic_cast(nnet_[i])->L2Penalty(l2); + } + } +} + + +inline void Nnet::L1Penalty(BaseFloat l1) { + for(int32 i=0; iIsUpdatable()) { + dynamic_cast(nnet_[i])->L1Penalty(l1); + } + } +} + + + + +} //namespace kaldi + +#endif + + diff --git a/src/nnet/nnet-test.cc b/src/nnet/nnet-test.cc new file mode 100644 index 000000000..35ebe7c82 --- /dev/null +++ b/src/nnet/nnet-test.cc @@ -0,0 +1,43 @@ +// nnet/nnet-test.cc 
+ +// Copyright 2010 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include + +#include "base/kaldi-common.h" +#include "nnet/nnet-component.h" +#include "nnet/nnet-nnet.h" + +using namespace kaldi; + +static void UnitTestSomething() { + KALDI_ERR << "Unimeplemented"; +} + + +static void UnitTestNnet() { + try { + UnitTestSomething(); + } catch (const std::exception& e) { + std::cerr << e.what(); + } +} + + +int main() { + UnitTestNnet(); +} diff --git a/src/nnetbin/Makefile b/src/nnetbin/Makefile new file mode 100644 index 000000000..c35da40bb --- /dev/null +++ b/src/nnetbin/Makefile @@ -0,0 +1,44 @@ + +all: +EXTRA_CXXFLAGS = -Wno-sign-compare +include ../kaldi.mk + +BINFILES = nnet-train-xent-hardlab-perutt + +OBJFILES = + +all: $(BINFILES) + + +TESTFILES = + + +$(BINFILES): ../nnet/kaldi-nnet.a ../matrix/kaldi-matrix.a ../util/kaldi-util.a ../base/kaldi-base.a + + + +# Rule below would expand to, e.g.: +# ../base/kaldi-base.a: +# make -c ../base kaldi-base.a +# -c option to make is same as changing directory. +%.a: + $(MAKE) -C ${@D} ${@F} + +clean: + rm *.o *.a $(TESTFILES) $(BINFILES) + +test: $(TESTFILES) + for x in $(TESTFILES); do ./$$x >&/dev/null || (echo "***test $$x failed***"; exit 1); done + echo Tests succeeded + +.valgrind: $(TESTFILES) + + +depend: + -$(CXX) -M $(CXXFLAGS) *.cc > .depend.mk + +# removing automatic making of "depend" as it's quite slow. 
+#.depend.mk: depend + +-include .depend.mk + diff --git a/src/nnetbin/nnet-train-xent-hardlab-perutt.cc b/src/nnetbin/nnet-train-xent-hardlab-perutt.cc new file mode 100644 index 000000000..77733a7f8 --- /dev/null +++ b/src/nnetbin/nnet-train-xent-hardlab-perutt.cc @@ -0,0 +1,155 @@ +// nnet/nnet-train-xent-hardlab-perutt.cc + +// Copyright 2011 Karel Vesely + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet/nnet-nnet.h" +#include "nnet/nnet-loss.h" +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "util/timer.h" + + +int main(int argc, char *argv[]) +{ + using namespace kaldi; + try { + const char *usage = + "Perform iteration of Neural Network training by stochastic gradient descent.\n" + "Usage: nnet-train-xent-hardlab-perutt [options] []\n" + "e.g.: \n" + " nnet-train-xent-hardlab-perutt nnet.init scp:train.scp ark:train.ali nnet.iter1\n"; + + ParseOptions po(usage); + bool binary = false, + crossvalidate = false; + po.Register("binary", &binary, "Write output in binary mode"); + po.Register("cross-validate", &crossvalidate, "Perform cross-validation (don't backpropagate)"); + + BaseFloat learn_rate = 0.008, + momentum = 0.0, + l2_penalty = 0.0, + l1_penalty = 0.0; + + po.Register("learn-rate", &learn_rate, "Learning rate"); + po.Register("momentum", &momentum, "Momentum"); + po.Register("l2-penalty", &l2_penalty, "L2 penalty (weight decay)"); + 
po.Register("l1-penalty", &l1_penalty, "L1 penalty (promote sparsity)"); + + std::string feature_transform; + po.Register("feature-transform", &feature_transform, "Feature transform Neural Network"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4-(crossvalidate?1:0)) { + po.PrintUsage(); + exit(1); + } + + std::string model_filename = po.GetArg(1), + feature_rspecifier = po.GetArg(2), + alignments_rspecifier = po.GetArg(3); + + std::string target_model_filename; + if(!crossvalidate) { + target_model_filename = po.GetArg(4); + } + + + using namespace kaldi; + typedef kaldi::int32 int32; + + + Nnet nnet_transf; + if(feature_transform != "") { + nnet_transf.Read(feature_transform); + } + + Nnet nnet; + nnet.Read(model_filename); + + nnet.LearnRate(learn_rate,NULL); + nnet.Momentum(momentum); + nnet.L2Penalty(l2_penalty); + nnet.L1Penalty(l1_penalty); + + kaldi::int64 tot_t = 0; + + SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessInt32VectorReader alignments_reader(alignments_rspecifier); + + Xent xent; + + Matrix feats_transf, nnet_out, glob_err; + + Timer tim; + KALDI_LOG << (crossvalidate?"CROSSVALIDATE":"TRAINING") << " STARTED"; + + int32 num_done = 0, num_no_alignment = 0, num_other_error = 0; + for (; !feature_reader.Done(); feature_reader.Next()) { + std::string key = feature_reader.Key(); + if (!alignments_reader.HasKey(key)) { + num_no_alignment++; + } else { + const Matrix &mat = feature_reader.Value(); + const std::vector &alignment = alignments_reader.Value(key); + + //std::cout << mat; + + if ((int32)alignment.size() != mat.NumRows()) { + KALDI_WARN << "Alignment has wrong size "<< (alignment.size()) << " vs. 
"<< (mat.NumRows()); + num_other_error++; + continue; + } + + if(num_done % 10000 == 0) std::cout << num_done << ", " << std::flush; + num_done++; + + nnet_transf.Feedforward(mat,&feats_transf); + nnet.Propagate(feats_transf,&nnet_out); + //std::cout << "\nNETOUT" << nnet_out; + xent.Eval(nnet_out,alignment,&glob_err); + //std::cout << "\nALIGN" << alignment[0] << " "<< alignment[1]<< " "<< alignment[2]; + //std::cout << "\nGLOBERR" << glob_err; + if(!crossvalidate) { + nnet.Backpropagate(glob_err,NULL); + } + + tot_t += mat.NumRows(); + } + } + + if(!crossvalidate) { + nnet.Write(target_model_filename,binary); + } + + std::cout << "\n" << std::flush; + + KALDI_LOG << (crossvalidate?"CROSSVALIDATE":"TRAINING") << " FINISHED " + << tim.Elapsed() << "s, fps" << tot_t/tim.Elapsed(); + + KALDI_LOG << "Done " << num_done << " files, " << num_no_alignment + << " with no alignments, " << num_other_error + << " with other errors."; + + KALDI_LOG << xent.Report(); + + + return 0; + } catch(const std::exception& e) { + std::cerr << e.what(); + return -1; + } +}