Adding the neural networks for phoneme-state classification

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@51 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
Karel Vesely 2011-05-22 21:17:34 +00:00
Parent 2fc231f3b9
Commit 0769ce57ee
50 changed files: 3842 additions and 2 deletions


@ -5,4 +5,7 @@ scripts for a sequence of experiments.
s1: This setup contains experiments with GMM-based systems using various
maximum-likelihood techniques, including global and speaker-specific transforms.
See a parallel setup in ../wsj/s1
s2: This setup is an experiment with a hybrid MLP system trained by
stochastic gradient descent.

0
egs/rm/s2/NOTES Normal file

0
egs/rm/s2/RESULTS Normal file

1
egs/rm/s2/conf/mfcc.conf Normal file

@ -0,0 +1 @@
--use-energy=false # only non-default option.

22
egs/rm/s2/conf/topo.proto Normal file

@ -0,0 +1,22 @@
<Topology>
<TopologyEntry>
<ForPhones>
NONSILENCEPHONES
</ForPhones>
<State> 0 <PdfClass> 0 <Transition> 0 0.75 <Transition> 1 0.25 </State>
<State> 1 <PdfClass> 1 <Transition> 1 0.75 <Transition> 2 0.25 </State>
<State> 2 <PdfClass> 2 <Transition> 2 0.75 <Transition> 3 0.25 </State>
<State> 3 </State>
</TopologyEntry>
<TopologyEntry>
<ForPhones>
SILENCEPHONES
</ForPhones>
<State> 0 <PdfClass> 0 <Transition> 0 0.25 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 </State>
<State> 1 <PdfClass> 1 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 <Transition> 4 0.25 </State>
<State> 2 <PdfClass> 2 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 <Transition> 4 0.25 </State>
<State> 3 <PdfClass> 3 <Transition> 1 0.25 <Transition> 2 0.25 <Transition> 3 0.25 <Transition> 4 0.25 </State>
<State> 4 <PdfClass> 4 <Transition> 4 0.25 <Transition> 5 0.75 </State>
<State> 5 </State>
</TopologyEntry>
</Topology>
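
NONSILENCEPHONES and SILENCEPHONES above are placeholders for the integer lists of non-silence and silence phones; some preparation step (presumably steps/prepare_graphs.sh, which run.sh calls below) has to substitute them before the topology can be used. A minimal sketch of such a substitution, assuming the colon-separated lists produced by scripts/silphones.pl live in data/silphones.csl and data/nonsilphones.csl (paths assumed, not fixed by this commit):

  # sketch only: replace the placeholders with space-separated phone-id lists
  silphonelist=`cat data/silphones.csl | sed 's/:/ /g'`
  nonsilphonelist=`cat data/nonsilphones.csl | sed 's/:/ /g'`
  sed -e "s:NONSILENCEPHONES:$nonsilphonelist:" -e "s:SILENCEPHONES:$silphonelist:" \
    conf/topo.proto > data/topo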


@ -0,0 +1,69 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# usage: make_trans.pl prefix in.flist input.snr out.txt out.scp
# prefix is the first letters of the database "key" (the rest is derived from
# the speaker and utterance names).
# in.flist is just a list of filenames, probably of .sph files.
# input.snr is an snr-format file from the RM dataset,
# e.g. $rootdir/rm1_audio1/rm1/doc/al_sents.snr.
# out.txt is the output transcriptions in the format "key word1 word2 ... wordN\n"
# out.scp is the output scp file, which is like in.flist but has the
# database-key first on each line.
if(@ARGV != 5) {
die "usage: make_trans.sh prefix in.flist input.snr out.txt out.scp\n";
}
($prefix, $in_flist, $input_snr, $out_txt, $out_scp) = @ARGV;
open(F, "<$input_snr") || die "Opening SNOR file $input_snr";
while(<F>) {
if(m/^;/) { next; }
m/(.+) \((.+)\)/ || die "bad line $_";
$T{$2} = $1;
}
close(F);
open(G, "<$in_flist") || die "Opening file list $in_flist";
open(O, ">$out_txt") || die "Open output transcription file $out_txt";
open(P, ">$out_scp") || die "Open output scp file $out_scp";
while(<G>) {
$_ =~ m:/(\w+)/(\w+)\.sph\s+$:i || die "bad scp line $_";
$spkname = $1;
$uttname = $2;
$uttname =~ tr/a-z/A-Z/;
defined $T{$uttname} || die "no trans for sent $uttname";
$spkname =~ s/_//g; # remove underscore from spk name to make key nicer.
$key = $prefix . "_" . $spkname . "_" . $uttname;
$key =~ tr/A-Z/a-z/; # Make it all lower case.
# to make the numerical and string-sorted orders the same.
print O "$key $T{$uttname}\n";
print P "$key $_";
$n++;
}
close(O) || die "Closing output.";
close(P) || die "Closing output.";
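
As a concrete illustration of the key format this script produces (purely hypothetical: the path, utterance name and sentence below are invented, not taken from the RM corpus):

  echo '/tmp/demo/bar_1/AB0001.sph' > in.flist
  echo 'SHOW THE GRID (AB0001)' > in.snr
  ./make_trans.pl trn in.flist in.snr out.txt out.scp
  cat out.txt   # -> trn_bar1_ab0001 SHOW THE GRID
  cat out.scp   # -> trn_bar1_ab0001 /tmp/demo/bar_1/AB0001.sph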

92
egs/rm/s2/data_prep/run.sh Executable file

@ -0,0 +1,92 @@
# This script should be run from the directory where it is located (i.e. data_prep)
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# The input is the 3 CDs from the LDC distribution of Resource Management.
# The script's argument is a directory which has three subdirectories:
# rm1_audio1 rm1_audio2 rm2_audio
if [ $# != 1 ]; then
echo "Usage: ./run.sh /path/to/RM"
exit 1;
fi
RMROOT=$1
if [ ! -d $RMROOT/rm1_audio1 -o ! -d $RMROOT/rm1_audio2 ]; then
echo "Error: run.sh requires a directory argument that contains rm1_audio1 and rm1_audio2"
exit 1;
fi
if [ ! -d $RMROOT/rm2_audio ]; then
echo "**Warning: $RMROOT/rm2_audio does not exist; won't create spk2gender.map file correctly***"
sleep 1
fi
(
find $RMROOT/rm1_audio1/rm1/ind_trn -iname '*.sph';
find $RMROOT/rm1_audio2/2_4_2/rm1/ind/dev_aug -iname '*.sph';
) | perl -ane ' m:/sa\d.sph:i || m:/sb\d\d.sph:i || print; ' > train_sph.flist
# make_trans.pl also creates the utterance id's and the kaldi-format scp file.
./make_trans.pl trn train_sph.flist $RMROOT/rm1_audio1/rm1/doc/al_sents.snr train_trans.txt train_sph.scp
mv train_trans.txt tmp; sort -k 1 tmp > train_trans.txt
mv train_sph.scp tmp; sort -k 1 tmp > train_sph.scp
sph2pipe=`cd ../../../..; echo $PWD/tools/sph2pipe_v2.5/sph2pipe`
if [ ! -f $sph2pipe ]; then
echo "Could not find the sph2pipe program at $sph2pipe";
exit 1;
fi
awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < train_sph.scp > train_wav.scp
cat train_wav.scp | perl -ane 'm/^(\w+_(\w+)\w_\w+) / || die; print "$1 $2\n"' > train.utt2spk
cat train.utt2spk | sort -k 2 | ../scripts/utt2spk_to_spk2utt.pl > train.spk2utt
for ntest in 1_mar87 2_oct87 4_feb89 5_oct89 6_feb91 7_sep92; do
n=`echo $ntest | cut -d_ -f 1`
test=`echo $ntest | cut -d_ -f 2`
root=$RMROOT/rm1_audio2/2_4_2
for x in `grep -v ';' $root/rm1/doc/tests/$ntest/${n}_indtst.ndx`; do
echo "$root/$x ";
done > test_${test}_sph.flist
done
# make_trans.pl also creates the utterance id's and the kaldi-format scp file.
for test in mar87 oct87 feb89 oct89 feb91 sep92; do
./make_trans.pl ${test} test_${test}_sph.flist $RMROOT/rm1_audio1/rm1/doc/al_sents.snr test_${test}_trans.txt test_${test}_sph.scp
mv test_${test}_trans.txt tmp; sort -k 1 tmp > test_${test}_trans.txt
mv test_${test}_sph.scp tmp; sort -k 1 tmp > test_${test}_sph.scp
awk '{printf("%s '$sph2pipe' -f wav %s |\n", $1, $2);}' < test_${test}_sph.scp > test_${test}_wav.scp
cat test_${test}_wav.scp | perl -ane 'm/^(\w+_(\w+)\w_\w+) / || die; print "$1 $2\n"' > test_${test}.utt2spk
cat test_${test}.utt2spk | sort -k 2 | ../scripts/utt2spk_to_spk2utt.pl > test_${test}.spk2utt
done
cat $RMROOT/rm1_audio2/2_5_1/rm1/doc/al_spkrs.txt \
$RMROOT/rm2_audio/3-1.2/rm2/doc/al_spkrs.txt | \
perl -ane 'tr/A-Z/a-z/;print;' | grep -v ';' | \
awk '{print $1, $2}' > spk2gender.map
../scripts/make_rm_lm.pl $RMROOT/rm1_audio1/rm1/doc/wp_gram.txt > G.txt
# Getting lexicon
../scripts/make_rm_dict.pl $RMROOT/rm1_audio2/2_4_2/score/src/rdev/pcdsril.txt > lexicon.txt
echo Succeeded.


@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
fake=false
if [ "$1" == "--fake" ]; then
fake=true
shift
fi
sphdir=$1 # e.g. /mnt/matylda2/data/RM
wavdir=$2 # e.g. /mnt/matylda6/jhu09/qpovey/kaldi_rm_wav
flistin=$3 # e.g. train_sph.flist, contains sph files in sphdir
flistout=$4 # e.g. train_wav.flist, contains wav files in wavdir
if [ $fake == false ]; then
for x in `cat $flistin`; do
y=`echo $x | sed s:$sphdir:$wavdir: | sed s:.sph:.wav:`;
mkdir -p `dirname $y`
../../tools/sph2pipe_v2.5/sph2pipe -f wav $x $y || exit 1;
done
fi
cat $flistin | sed s:$sphdir:$wavdir: | sed s:.sph:.wav: > $flistout || exit 1;

1
egs/rm/s2/path.sh Executable file

@ -0,0 +1 @@
export PATH=$PATH:../../../src/bin:../../../tools/openfst/bin:../../../src/fstbin/:../../../src/gmmbin/:../../../src/featbin/:../../../src/fgmmbin:../../../src/sgmmbin:../../../src/nnetbin

71
egs/rm/s2/run.sh Normal file

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
#exit 1 # Don't run this... it's to be run line by line from the shell.
# This script file cannot be run as-is; some paths in it need to be changed
# before you can run it.  Search for /path/to.
# It is recommended that you do not invoke this file from the shell, but
# run the commands in it one by one, by hand.
# The step in data_prep/ will need to be modified for your system.
# The first step is data preparation; it just creates some text files and is fast.
# If you are not on the BUT system, you will have to change run.sh to reflect
# your own paths.
#
#Example arguments to run.sh: /mnt/matylda2/data/RM, /ais/gobi2/speech/RM, /cygdrive/e/data/RM
# RM is a directory with subdirectories rm1_audio1, rm1_audio2, rm2_audio
cd data_prep
#*** You have to change the pathname below.***
./run.sh /path/to/RM
cd ..
mkdir -p data
( cd data; cp ../data_prep/{train,test*}.{spk2utt,utt2spk} . ; cp ../data_prep/spk2gender.map . )
# This next step converts the lexicon, grammar, etc., into FST format.
steps/prepare_graphs.sh
# Next, make sure that "exp/" is someplace you can write a significant amount of
# data to (e.g. make it a link to a file on some reasonably large file system).
# If it doesn't exist, the scripts below will make the directory "exp".
# mfcc should be set to some place to put training mfcc's
# where you have space.
#e.g.: mfccdir=/mnt/matylda6/jhu09/qpovey/kaldi_rm_mfccb
mfccdir=/path/to/mfccdir
steps/make_mfcc_train.sh $mfccdir
steps/make_mfcc_test.sh $mfccdir
# First, we train a monophone GMM system to get the training labels.
steps/train_mono.sh
steps/decode_mono.sh &
# Now we train the MLP;
# it takes CMVN-normalized MFCCs as input and produces phoneme-state posteriors as output.
steps/train_nnet.sh
#steps/decode_nnet.sh #TODO


@ -0,0 +1,58 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Adds some specified number of disambig symbols to a symbol table.
# Adds these as #1, #2, etc.
# If the --include-zero option is specified, includes an extra one
# #0.
if(!(@ARGV == 2 || (@ARGV ==3 && $ARGV[0] eq "--include-zero"))) {
die "Usage: add_disambig.pl [--include-zero] symtab.txt num_extra > symtab_out.txt ";
}
if(@ARGV == 3) {
$include_zero = 1;
$ARGV[0] eq "--include-zero" || die "Bad option/first argument $ARGV[0]";
shift @ARGV;
} else {
$include_zero = 0;
}
$input = $ARGV[0];
$nsyms = $ARGV[1];
open(F, "<$input") || die "Opening file $input";
while(<F>) {
@A = split(" ", $_);
@A == 2 || die "Bad line $_";
$lastsym = $A[1];
print;
}
if(!defined($lastsym)){
die "Empty symbol file?";
}
if($include_zero) {
$lastsym++;
print "#0 $lastsym\n";
}
for($n = 1; $n <= $nsyms; $n++) {
$y = $n + $lastsym;
print "#$n $y\n";
}
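
A small worked example of the numbering above (the three-line symbol table is made up):

  printf '<eps> 0\naa 1\nae 2\n' > symtab.txt
  ./add_disambig.pl --include-zero symtab.txt 2 > symtab_out.txt
  # symtab_out.txt repeats the three input lines and then appends:
  #   #0 3
  #   #1 4
  #   #2 5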


@ -0,0 +1,101 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Adds disambiguation symbols to a lexicon.
# The output is still in the normal lexicon format.
# Disambig syms are numbered #1, #2, #3, etc. (#0 is
# reserved for use in the grammar).
# Outputs the number of disambig syms to the standard output.
if(@ARGV != 2) {
die "Usage: add_lex_disambig.pl lexicon.txt lexicon_disambig.txt "
}
$lexfn = shift @ARGV;
$lexoutfn = shift @ARGV;
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
# (1) Read in the lexicon.
@L = ( );
while(<L>) {
@A = split(" ", $_);
push @L, join(" ", @A);
}
# (2) Work out the count of each phone-sequence in the
# lexicon.
foreach $l (@L) {
@A = split(" ", $l);
shift @A; # Remove word.
$count{join(" ",@A)}++;
}
# (3) For each left sub-sequence of each phone-sequence, note down
# that it exists (for identifying prefixes of longer strings).
foreach $l (@L) {
@A = split(" ", $l);
shift @A; # Remove word.
while(@A > 0) {
pop @A; # Remove last phone
$issubseq{join(" ",@A)} = 1;
}
}
# (4) For each entry in the lexicon:
# if the phone sequence is unique and is not a
# prefix of another word, no disambig symbol is needed.
# Else output #1, or #2, #3, ... if the same phone-seq
# has already been assigned a disambig symbol.
open(O, ">$lexoutfn") || die "Opening lexicon file $lexoutfn for writing.\n";
$max_disambig = 0;
foreach $l (@L) {
@A = split(" ", $l);
$word = shift @A;
$phnseq = join(" ",@A);
if(!defined $issubseq{$phnseq}
&& $count{$phnseq}==1) {
; # Do nothing.
} else {
if($phnseq eq "") { # need disambig symbols for the empty string
# that are not used anywhere else.
$max_disambig++;
$reserved{$max_disambig} = 1;
$phnseq = "#$max_disambig";
} else {
$curnumber = $disambig_of{$phnseq};
if(!defined($curnumber)) { $curnumber = 0; }
$curnumber++; # now 1 or 2, ...
while(defined $reserved{$curnumber} ) { $curnumber++; } # skip over reserved symbols
if($curnumber > $max_disambig) {
$max_disambig = $curnumber;
}
$disambig_of{$phnseq} = $curnumber;
$phnseq = $phnseq . " #" . $curnumber;
}
}
print O "$word\t$phnseq\n";
}
print $max_disambig . "\n";
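
A worked example of the rule above (toy lexicon, not from RM): "r eh d" occurs twice and is also a prefix of a longer pronunciation, so both of those entries get markers, while the longer, unique pronunciation needs none.

  printf 'red r eh d\nread r eh d\nreddish r eh d ih sh\n' > lexicon.txt
  ./add_lex_disambig.pl lexicon.txt lexicon_disambig.txt   # prints 2
  cat lexicon_disambig.txt
  #   red      r eh d #1
  #   read     r eh d #2
  #   reddish  r eh d ih sh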

40
egs/rm/s2/scripts/filter_scp.pl Executable file

@ -0,0 +1,40 @@
#!/usr/bin/perl -w
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids and filters an scp
# file (or any file whose first field is an utterance id), printing
# out only those lines whose first field is in id_list.
if(@ARGV < 1 || @ARGV > 2) {
die "Usage: filter_scp.pl id_list [in.scp] > out.scp ";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
while(<F>) {
@A = split;
@A>=1 || die "Invalid id-list file line $_";
$seen{$A[0]} = 1;
}
while(<>) {
@A = split;
@A > 0 || die "Invalid scp file line $_";
if($seen{$A[0]}) {
print $_;
}
}


@ -0,0 +1,75 @@
#!/usr/bin/python -u
# ./gen_hamm_dct.py
# script generating NN initialization for training with TNet
#
# author: Karel Vesely
#
import math, random
import sys
from optparse import OptionParser
parser = OptionParser()
parser.add_option('--dim', dest='dim', help='d1:d2:d3 layer dimensions in the network')
parser.add_option('--gauss', dest='gauss', help='use gaussian noise for weights', action='store_true', default=False)
parser.add_option('--negbias', dest='negbias', help='use uniform [-4.1,-3.9] for bias (default all 0.0)', action='store_true', default=False)
parser.add_option('--inputscale', dest='inputscale', help='scale the weights by 3/sqrt(Ninputs)', action='store_true', default=False)
parser.add_option('--linBNdim', dest='linBNdim', help='dim of linear bottleneck (sigmoids will be omitted, bias will be zero)',default=0)
(options, args) = parser.parse_args()
if(options.dim == None):
parser.print_help()
sys.exit(1)
dimStrL = options.dim.split(':')
dimL = []
for i in range(len(dimStrL)):
dimL.append(int(dimStrL[i]))
#print dimL,'linBN',options.linBNdim
for layer in range(len(dimL)-1):
print '<biasedlinearity>', dimL[layer+1], dimL[layer]
#weight matrix
print '['
for row in range(dimL[layer+1]):
for col in range(dimL[layer]):
if(options.gauss):
if(options.inputscale):
print 3/math.sqrt(dimL[layer])*random.gauss(0.0,1.0),
else:
print 0.1*random.gauss(0.0,1.0),
else:
if(options.inputscale):
print (random.random()-0.5)*2*3/math.sqrt(dimL[layer]),
else:
print random.random()/5.0-0.1,
print #newline for each row
print ']'
#bias vector
print '[',
for idx in range(dimL[layer+1]):
if(int(options.linBNdim) == dimL[layer+1]):
print '0.0',
elif(options.negbias):
print random.random()/5.0-4.1,
else:
print '0.0',
print ']'
if(int(options.linBNdim) != dimL[layer+1]):
if(layer == len(dimL)-2):
print '<softmax>', dimL[layer+1], dimL[layer+1]
else:
print '<sigmoid>', dimL[layer+1], dimL[layer+1]
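
As an illustration of how this initializer might be invoked (the script path, layer sizes and output filename below are assumptions; the actual call presumably lives in steps/train_nnet.sh, which is not part of this hunk):

  # hypothetical: 351-dim spliced-MFCC input, one hidden layer of 1024 units,
  # and 1500 output units (one per phoneme state / pdf-id)
  python scripts/gen_mlp_init.py --dim=351:1024:1500 --gauss --negbias > nnet.init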

69
egs/rm/s2/scripts/int2sym.pl Executable file

@ -0,0 +1,69 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
$ignore_noninteger = 0;
$ignore_first_field = 0;
for($x = 0; $x < 2; $x++) {
if($ARGV[0] eq "--ignore-noninteger") { $ignore_oov = 1; shift @ARGV; }
if($ARGV[0] eq "--ignore-first-field") { $ignore_first_field = 1; shift @ARGV; }
}
$symtab = shift @ARGV;
if(!defined $symtab) {
die "Usage: sym2int.pl symtab [input transcriptions] > output transcriptions\n";
}
open(F, "<$symtab") || die "Error opening symbol table file $symtab";
while(<F>) {
@A = split(" ", $_);
@A == 2 || die "bad line in symbol table file: $_";
$int2sym{$A[1]} = $A[0];
}
$error = 0;
while(<>) {
@A = split(" ", $_);
if(@A == 0) {
die "Empty line in transcriptions input.";
}
if($ignore_first_field) {
$key = shift @A;
print $key . " ";
}
foreach $a (@A) {
if($a !~ m:^\d+$:) { # not all digits..
if($ignore_noninteger) {
print $a . " ";
next;
} else {
if($a eq $A[0]) {
die "int2sym.pl: found noninteger token $a (try --ignore-first-field)\n";
} else {
die "int2sym.pl: found noninteger token $a (try --ignore-noninteger if valid input)\n";
}
}
}
$s = $int2sym{$a};
if(!defined ($s)) {
die "int2sym.pl: integer $a not in symbol table $symtab.";
}
print $s . " ";
}
print "\n";
}

45
egs/rm/s2/scripts/is_sorted.sh Executable file

@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Usage: is_sorted.sh [script-file]
# This script returns 0 (success) if the script file argument [or standard input]
# is sorted and 1 otherwise.
export LC_ALL=C
if [ $# == 0 ]; then
scp=-
fi
if [ $# == 1 ]; then
scp=$1
fi
if [ $# -gt 1 -o "$1" == "--help" -o "$1" == "-h" ]; then
echo "Usage: is_sorted.sh [script-file]"
exit 1
fi
cat $scp > /tmp/tmp1.$$
sort /tmp/tmp1.$$ > /tmp/tmp2.$$
cmp /tmp/tmp1.$$ /tmp/tmp2.$$ >/dev/null
ret=$?
rm /tmp/tmp1.$$ /tmp/tmp2.$$
if [ $ret == 0 ]; then
exit 0;
else
echo "is_sorted.sh: script file $scp is not sorted";
exit 1;
fi


@ -0,0 +1,112 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# makes lexicon FST (no pron-probs involved).
if(@ARGV != 1 && @ARGV != 3) {
die "Usage: make_lexicon_fst.pl lexicon.txt [silprob silphone] > lexiconfst.txt"
}
$lexfn = shift @ARGV;
if(@ARGV == 0) {
$silprob = 0.0;
} else {
($silprob,$silphone) = @ARGV;
}
if($silprob != 0.0) {
$silprob < 1.0 || die "Sil prob cannot be >= 1.0";
$silcost = -log($silprob);
$nosilcost = -log(1.0 - $silprob);
}
open(L, "<$lexfn") || die "Error opening lexicon $lexfn";
if( $silprob == 0.0 ) { # No optional silences: just have one (loop+final) state which is numbered zero.
$loopstate = 0;
$nextstate = 1; # next unallocated state.
while(<L>) {
@A = split(" ", $_);
$w = shift @A;
if(@A == 0) { # For empty words (<s> and </s>) insert no optional
# silence (not needed, as adjacent words supply it).
# Actually we only hit this case for the lexicon without disambig
# symbols, but it never matters, as training transcripts don't have <s> or </s>.
print "$loopstate\t$loopstate\t<eps>\t$w\n";
} else {
$s = $loopstate;
$word_or_eps = $w;
while (@A > 0) {
$p = shift @A;
if(@A > 0) {
$ns = $nextstate++;
} else {
$ns = $loopstate;
}
print "$s\t$ns\t$p\t$word_or_eps\n";
$word_or_eps = "<eps>";
$s = $ns;
}
}
}
print "$loopstate\t0\n"; # final-cost.
} else { # have silence probs.
$startstate = 0;
$loopstate = 1;
$silstate = 2; # state from where we go to loopstate after emitting silence.
$nextstate = 3;
print "$startstate\t$loopstate\t<eps>\t<eps>\t$nosilcost\n"; # no silence.
print "$startstate\t$loopstate\t$silphone\t<eps>\t$silcost\n"; # silence.
print "$silstate\t$loopstate\t$silphone\t<eps>\n"; # no cost.
while(<L>) {
@A = split(" ", $_);
$w = shift @A;
if(@A == 0) { # For empty words (<s> and </s>) insert no optional
# silence (not needed, as adjacent words supply it).
# Actually we only hit this case for the lexicon without disambig
# symbols, but it never matters, as training transcripts don't have <s> or </s>.
print "$loopstate\t$loopstate\t<eps>\t$w\n";
} else {
$is_silence_word = (@A == 1 && $A[0] eq $silphone); # boolean.
$s = $loopstate;
$word_or_eps = $w;
while (@A > 0) {
$p = shift @A;
if(@A > 0) {
$ns = $nextstate++;
print "$s\t$ns\t$p\t$word_or_eps\n";
$word_or_eps = "<eps>";
$s = $ns;
} else {
if(! $is_silence_word) {
# This is non-deterministic but relatively compact,
# and avoids epsilons.
print "$s\t$loopstate\t$p\t$word_or_eps\t$nosilcost\n";
print "$s\t$silstate\t$p\t$word_or_eps\t$silcost\n";
} else {
# no point putting opt-sil after silence word.
print "$s\t$loopstate\t$p\t$word_or_eps\n";
}
$word_or_eps = "<eps>";
}
}
}
}
print "$loopstate\t0\n"; # final-cost.
}
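
For reference, a toy run of the no-optional-silence branch (single-word lexicon, invented): each pronunciation becomes a path from the loop state back to itself, emitting the word on the first arc and <eps> on the rest.

  printf 'red r eh d\n' > lexicon.txt
  ./make_lexicon_fst.pl lexicon.txt
  #   0   1   r    red
  #   1   2   eh   <eps>
  #   2   0   d    <eps>
  #   0   0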


@ -0,0 +1,37 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# make_phones_symtab.pl < lexicon.txt > phones.txt
while(<>) {
@A = split(" ", $_);
for ($i=2; $i<@A; $i++) {
$P{$A[$i]} = 1; # seen it.
}
}
print "<eps>\t0\n";
$n = 1;
foreach $p (sort keys %P) {
if($p ne "<eps>") {
print "$p\t$n\n";
$n++;
}
}
print "sil\t$n\n";

130
egs/rm/s2/scripts/make_rm_dict.pl Executable file

@ -0,0 +1,130 @@
#!/usr/bin/perl
# Copyright 2010-2011 Yanmin Qian Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This file takes as input the file pcdsril.txt that comes with the RM
# distribution, and creates the dictionary used in RM training.
# make_rm_dict.pl pcdsril.txt > dict.txt
if (@ARGV != 1) {
die "usage: make_rm_dict.pl pcdsril.txt > dict.txt\n";
}
unless (open(IN_FILE, "@ARGV[0]")) {
die ("can't open @ARGV[0]");
}
while ($line = <IN_FILE>)
{
chop($line);
if (($line =~ /^[a-z]/))
{
$line =~ s/\+1//g;
@LineArray = split(/\s+/,$line);
@LineArray[0] = uc(@LineArray[0]);
printf "%-16s", @LineArray[0];
for ($i = 1; $i < @LineArray; $i ++)
{
if (@LineArray[$i] eq 'q')
{}
elsif (@LineArray[$i] eq 'zh')
{
printf "sh ";
}
elsif (@LineArray[$i] eq 'eng')
{
printf "ng ";
}
elsif (@LineArray[$i] eq 'hv')
{
printf "hh ";
}
elsif (@LineArray[$i] eq 'em')
{
printf "m ";
}
elsif (@LineArray[$i] eq 'axr')
{
printf "er ";
}
elsif (@LineArray[$i] eq 'tcl')
{
if (@LineArray[$i+1] ne 't')
{
printf "td ";
}
}
elsif (@LineArray[$i] eq 'dcl')
{
if (@LineArray[$i+1] ne 'd')
{
printf "dd ";
}
}
elsif (@LineArray[$i] eq 'kcl')
{
if (@LineArray[$i+1] ne 'k')
{
printf "kd ";
}
}
elsif (@LineArray[$i] eq 'pcl')
{
if (@LineArray[$i+1] ne 'p')
{
printf "pd ";
}
}
elsif (@LineArray[$i] eq 'bcl')
{
if (@LineArray[$i+1] ne 'b')
{
printf "b ";
}
}
elsif (@LineArray[$i] eq 'gcl')
{
if (@LineArray[$i+1] ne 'g')
{
printf "g ";
}
}
elsif (@LineArray[$i] eq 't')
{
if (@LineArray[$i+1] ne 's')
{
printf "@LineArray[$i] ";
}
else
{
printf "ts ";
$i++;
}
}
else
{
printf "@LineArray[$i] ";
}
}
printf "\n";
}
}
printf "!SIL sil\n";
close(IN_FILE);

119
egs/rm/s2/scripts/make_rm_lm.pl Executable file

@ -0,0 +1,119 @@
#!/usr/bin/perl
# Copyright 2010-2011 Yanmin Qian Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This file takes as input the file wp_gram.txt that comes with the RM
# distribution, and creates the language model as an acceptor in FST form.
# make_rm_lm.pl wp_gram.txt > G.txt
if (@ARGV != 1) {
print "usage: make_rm_lm.pl wp_gram.txt > G.txt\n";
exit(0);
}
unless (open(IN_FILE, "@ARGV[0]")) {
die ("can't open @ARGV[0]");
}
$flag = 0;
$count_wrd = 0;
$cnt_ends = 0;
$init = "";
while ($line = <IN_FILE>)
{
chop($line);
$line =~ s/ //g;
if(($line =~ /^>/))
{
if($flag == 0)
{
$flag = 1;
}
$line =~ s/>//g;
$hashcnt{$init} = $i;
$init = $line;
$i = 0;
$count_wrd++;
@LineArray[$count_wrd - 1] = $init;
$hashwrd{$init} = 0;
}
elsif($flag != 0)
{
$hash{$init}[$i] = $line;
$i++;
if($line =~ /SENTENCE-END/)
{
$cnt_ends++;
}
}
else
{}
}
$hashcnt{$init} = $i;
$num = 0;
$weight = 0;
$init_wrd = "SENTENCE-END";
$hashwrd{$init_wrd} = @LineArray;
for($i = 0; $i < $hashcnt{$init_wrd}; $i++)
{
$weight = -log(1/$hashcnt{$init_wrd});
$hashwrd{$hash{$init_wrd}[$i]} = $i + 1;
print "0 $hashwrd{$hash{$init_wrd}[$i]} $hash{$init_wrd}[$i] $hash{$init_wrd}[$i] $weight\n";
}
$num = $i;
for($i = 0; $i < @LineArray; $i++)
{
if(@LineArray[$i] eq 'SENTENCE-END')
{}
else
{
if($hashwrd{@LineArray[$i]} == 0)
{
$num++;
$hashwrd{@LineArray[$i]} = $num;
}
for($j = 0; $j < $hashcnt{@LineArray[$i]}; $j++)
{
$weight = -log(1/$hashcnt{@LineArray[$i]});
if($hashwrd{$hash{@LineArray[$i]}[$j]} == 0)
{
$num++;
$hashwrd{$hash{@LineArray[$i]}[$j]} = $num;
}
if($hash{@LineArray[$i]}[$j] eq 'SENTENCE-END')
{
print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} <eps> <eps> $weight\n"
}
else
{
print "$hashwrd{@LineArray[$i]} $hashwrd{$hash{@LineArray[$i]}[$j]} $hash{@LineArray[$i]}[$j] $hash{@LineArray[$i]}[$j] $weight\n";
}
}
}
}
print "$hashwrd{$init_wrd} 0\n";
close(IN_FILE);

102
egs/rm/s2/scripts/make_roots.pl Executable file

@ -0,0 +1,102 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Written by Dan Povey 9/21/2010. Apache 2.0 License.
# This version of make_roots.pl is specialized for RM.
# This script creates the file roots.txt which is an input to train-tree.cc. It
# specifies how the trees are built. The input file phone-sets.txt is a partial
# version of roots.txt in which phones are represented by their spelled form, not
# their symbol ids. E.g. at input, phone-sets.txt might contain:
# shared not-split sil
# Any phones not specified in phone-sets.txt but present in phones.txt will
# be given a default treatment. If the --separate option is given, we create
# a separate tree root for each of them, otherwise they are all lumped in one set.
# The arguments shared|not-shared and split|not-split are needed if any
# phones are not specified in phone-sets.txt. What they mean is as follows:
# if shared=="shared" then we share the tree-root between different HMM-positions
# (0,1,2). If split=="split" then we actually do decision tree splitting on
# that root, otherwise we forbid decision-tree splitting. (The main reason we might
# set this to "not-split" is for silence, when we want to ensure that the
# HMM-positions will remain with a single PDF id.)
$separate = 0;
if($ARGV[0] eq "--separate") {
$separate = 1;
shift @ARGV;
}
if(@ARGV != 4) {
die "Usage: make_roots.pl [--separate] phones.txt silence-phone-list[integer,colon-separated] shared|not-shared split|not-split > roots.txt\n";
}
($phonesfile, $silphones, $shared, $split) = @ARGV;
if($shared ne "shared" && $shared ne "not-shared") {
die "Third argument must be \"shared\" or \"not-shared\"\n";
}
if($split ne "split" && $split ne "not-split") {
die "Third argument must be \"split\" or \"not-split\"\n";
}
open(F, "<$phonesfile") || die "Opening file $phonesfile";
while(<F>) {
@A = split(" ", $_);
if(@A != 2) {
die "Bad line in phones symbol file: ".$_;
}
if($A[1] != 0) {
$symbol2id{$A[0]} = $A[1];
$id2symbol{$A[1]} = $A[0];
}
}
if($silphones == ""){
die "Empty silence phone list in make_roots.pl";
}
foreach $silphoneid (split(":", $silphones)) {
defined $id2symbol{$silphoneid} || die "No such silence phone id $silphoneid";
# Give each silence phone its own separate pdfs in each state, with
# no sharing (in this recipe; WSJ is different, but in this recipe there
# is only one silence phone anyway).
$issil{$silphoneid} = 1;
print "not-shared not-split $silphoneid\n";
}
$idlist = "";
$remaining_phones = "";
if($separate){
foreach $a (keys %id2symbol) {
if(!defined $issil{$a}) {
print "$shared $split $a\n";
}
}
} else {
print "$shared $split ";
foreach $a (keys %id2symbol) {
if(!defined $issil{$a}) {
print "$a ";
}
}
print "\n";
}
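
A toy invocation (the phone table is invented; the order of the non-silence phones on the shared line follows Perl hash ordering, so it may vary):

  printf '<eps> 0\nsil 1\naa 2\nae 3\n' > phones.txt
  ./make_roots.pl phones.txt 1 shared split
  #   not-shared not-split 1
  #   shared split 2 3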


@ -0,0 +1,39 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# make_words_symtab.pl < G.txt > words.txt
while(<>) {
@A = split(" ", $_);
if(@A >= 3) {
$W{$A[2]} = 1;
}
}
print "<eps>\t0\n";
$n = 1;
foreach $w (sort keys %W) {
if($w ne "<eps>") {
print "$w\t$n\n";
$n++;
}
}
print "!SIL\t$n\n";

107
egs/rm/s2/scripts/mkgraph.sh Executable file

@ -0,0 +1,107 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
reorder=true # Dan-style, make false for Mirko+Lukas's decoder.
for x in 1 2 3; do
if [ $1 == "--mono" ]; then
monophone_opts="--context-size=1 --central-position=0"
shift;
fi
if [ $1 == "--noreorder" ]; then
reorder=false # we set this for the Kaldi decoder.
shift;
fi
done
if [ $# != 3 ]; then
echo "Usage: scripts/mkgraph.sh <tree> <model> <graphdir>"
exit 1;
fi
if [ -f path.sh ]; then . path.sh; fi
tree=$1
model=$2
dir=$3
mkdir -p $dir
tscale=1.0
loopscale=0.1
fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminizestar --use-log=true | \
fstminimizeencoded > $dir/LG.fst
fstisstochastic $dir/LG.fst || echo "warning: LG not stochastic."
echo "Example string from LG.fst: "
echo
fstrandgen --select=log_prob $dir/LG.fst | fstprint --isymbols=data/phones_disambig.txt --osymbols=data/words.txt -
grep '#' data/phones_disambig.txt | awk '{print $2}' > $dir/disambig_phones.list
fstcomposecontext $monophone_opts \
--read-disambig-syms=$dir/disambig_phones.list \
--write-disambig-syms=$dir/disambig_ilabels.list \
$dir/ilabels < $dir/LG.fst >$dir/CLG.fst
# for debugging:
fstmakecontextsyms data/phones.txt $dir/ilabels > $dir/context_syms.txt
echo "Example string from CLG.fst: "
echo
fstrandgen --select=log_prob $dir/CLG.fst | fstprint --isymbols=$dir/context_syms.txt --osymbols=data/words.txt -
fstisstochastic $dir/CLG.fst || echo "warning: CLG not stochastic."
make-ilabel-transducer --write-disambig-syms=$dir/disambig_ilabels_remapped.list $dir/ilabels $tree $model $dir/ilabels.remapped > $dir/ilabel_map.fst
# Reduce size of CLG by remapping symbols...
fsttablecompose $dir/ilabel_map.fst $dir/CLG.fst | fstdeterminizestar --use-log=true \
| fstminimizeencoded > $dir/CLG2.fst
cat $dir/CLG2.fst | fstisstochastic || echo "warning: CLG2 is not stochastic."
make-h-transducer --disambig-syms-out=$dir/disambig_tstate.list \
--transition-scale=$tscale $dir/ilabels.remapped $tree $model > $dir/Ha.fst
fsttablecompose $dir/Ha.fst $dir/CLG2.fst | fstdeterminizestar --use-log=true \
| fstrmsymbols $dir/disambig_tstate.list | fstrmepslocal | fstminimizeencoded > $dir/HCLGa.fst
fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic"
add-self-loops --self-loop-scale=$loopscale --reorder=$reorder $model < $dir/HCLGa.fst > $dir/HCLG.fst
if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then
# No point doing this test if transition-scale not 1, as it is bound to fail.
fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic."
fi
#The next five lines are debug.
# The last two lines of this block print out some alignment info.
fstrandgen --select=log_prob $dir/HCLG.fst | fstprint --osymbols=data/words.txt > $dir/rand.txt
cat $dir/rand.txt | awk 'BEGIN{printf("0 ");} {if(NF>=3 && $3 != 0){ printf ("%d ",$3); }} END {print ""; }' > $dir/rand_align.txt
show-alignments data/phones.txt $model ark:$dir/rand_align.txt
cat $dir/rand.txt | awk ' {if(NF>=4 && $4 != "<eps>"){ printf ("%s ",$4); }} END {print ""; }'

115
egs/rm/s2/scripts/mkgraph_alt.sh Executable file

@ -0,0 +1,115 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This version of mkgraph.sh creates the C fst explicitly.
reorder=true # Dan-style, make false for Mirko+Lukas's decoder.
for x in 1 2 3; do
if [ $1 == "--mono" ]; then
monophone_opts="--context-size=1 --central-position=0"
shift;
fi
if [ $1 == "--noreorder" ]; then
reorder=false # we set this for the Kaldi decoder.
shift;
fi
done
if [ $# != 3 ]; then
echo "Usage: scripts/mkgraph.sh <tree> <model> <graphdir>"
exit 1;
fi
if [ -f path.sh ]; then . path.sh; fi
tree=$1
model=$2
dir=$3
mkdir -p $dir
tscale=1.0
loopscale=0.1
fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminizestar --use-log=true | \
fstminimizeencoded > $dir/LG.fst
fstisstochastic $dir/LG.fst || echo "warning: LG not stochastic."
echo "Example string from LG.fst: "
echo
fstrandgen --select=log_prob $dir/LG.fst | fstprint --isymbols=data/phones_disambig.txt --osymbols=data/words.txt -
grep '#' data/phones_disambig.txt | awk '{print $2}' > $dir/disambig_phones.list
subseq_sym=`tail -1 data/phones_disambig.txt | awk '{print $2+1;}'`
cp data/phones_disambig.txt $dir/phones_disambig_subseq.txt
echo '$' $subseq_sym >> $dir/phones_disambig_subseq.txt
fstmakecontextfst --read-disambig-syms=$dir/disambig_phones.list \
--write-disambig-syms=$dir/disambig_ilabels.list data/phones.txt $subseq_sym \
$dir/ilabels | fstarcsort --sort_type=olabel > $dir/C.fst
fstaddsubsequentialloop $subseq_sym $dir/LG.fst | \
fsttablecompose $dir/C.fst - > $dir/CLG.fst
# for debugging:
fstmakecontextsyms data/phones.txt $dir/ilabels > $dir/context_syms.txt
echo "Example string from CLG.fst: "
echo
fstrandgen --select=log_prob $dir/CLG.fst | fstprint --isymbols=$dir/context_syms.txt --osymbols=data/words.txt -
fstisstochastic $dir/CLG.fst || echo "warning: CLG not stochastic."
make-ilabel-transducer --write-disambig-syms=$dir/disambig_ilabels_remapped.list $dir/ilabels $tree $model $dir/ilabels.remapped > $dir/ilabel_map.fst
# Reduce size of CLG by remapping symbols...
fstcompose $dir/ilabel_map.fst $dir/CLG.fst | fstdeterminizestar --use-log=true \
| fstminimizeencoded > $dir/CLG2.fst
cat $dir/CLG2.fst | fstisstochastic || echo "warning: CLG2 is not stochastic."
make-h-transducer --disambig-syms-out=$dir/disambig_tstate.list \
--transition-scale=$tscale $dir/ilabels.remapped $tree $model > $dir/Ha.fst
fsttablecompose $dir/Ha.fst $dir/CLG2.fst | fstdeterminizestar --use-log=true \
| fstrmsymbols $dir/disambig_tstate.list | fstrmepslocal | fstminimizeencoded > $dir/HCLGa.fst
fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic"
add-self-loops --self-loop-scale=$loopscale --reorder=$reorder $model < $dir/HCLGa.fst > $dir/HCLG.fst
if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then
# No point doing this test if transition-scale not 1, as it is bound to fail.
fstisstochastic $dir/HCLG.fst || echo "Final HCLG is not stochastic."
fi
#The next five lines are debug.
# The last two lines of this block print out some alignment info.
fstrandgen --select=log_prob $dir/HCLG.fst | fstprint --osymbols=data/words.txt > $dir/rand.txt
cat $dir/rand.txt | awk 'BEGIN{printf("0 ");} {if(NF>=3 && $3 != 0){ printf ("%d ",$3); }} END {print ""; }' > $dir/rand_align.txt
show-alignments data/phones.txt $model ark:$dir/rand_align.txt
cat $dir/rand.txt | awk ' {if(NF>=4 && $4 != "<eps>"){ printf ("%s ",$4); }} END {print ""; }'


@ -0,0 +1,47 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script is part of a diagnostic step when using exponential transforms.
$map=$ARGV[0]; open(M,"<$map")||die "opening map file $map";
while(<M>){ @A=split(" ",$_); $map{$A[0]} = $A[1]; }
while(<STDIN>){
($spk,$warp)=split(" ",$_);
defined $map{$spk} || die "No gender info for speaker $spk";
$warps{$map{$spk}} = $warps{$map{$spk}} . "$warp ";
}
@K = sort keys %warps;
@K==2||die "wrong number of keys [empty warps file?]";
foreach $k ( @K ) {
$s = join(" ", sort { $a <=> $b } ( split(" ", $warps{$k}) )) ;
print "$k = [ $s ];\n";
}
# f,m may be reversed below; doesn't matter.
foreach $w ( split(" ", $warps{$K[0]}) ) {
$nf += 1; $sumf += $w; $sumf2 += $w*$w;
}
foreach $w ( split(" ", $warps{$K[1]}) ) {
$nm += 1; $summ += $w; $summ2 += $w*$w;
}
$sumf /= $nf; $sumf2 /= $nf;
$summ /= $nm; $summ2 /= $nm;
$sumf2 -= $sumf*$sumf;
$summ2 -= $summ*$summ;
$avgwithin = 0.5*($sumf2+$summ2 );
$diff = abs($sumf - $summ) / sqrt($avgwithin);
print "% class separation is $diff\n";

57
egs/rm/s2/scripts/silphones.pl Executable file

@ -0,0 +1,57 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# creates integer lists of silence and non-silence phones in files,
# e.g. silphones.csl="1:2:3 \n"
# and nonsilphones.csl="4:5:6:7:...:24\n";
if(@ARGV != 4) {
die "Usage: silphones.pl phones.txt \"sil1 sil2 sil3\" silphones.csl nonsilphones.csl";
}
($symtab, $sillist, $silphones, $nonsilphones) = @ARGV;
open(S,"<$symtab") || die "Opening symbol table $symtab";
foreach $s (split(" ", $sillist)) {
$issil{$s} = 1;
}
@sil = ();
@nonsil = ();
while(<S>){
@A = split(" ", $_);
@A == 2 || die "Bad line $_ in phone-symbol-table file $symtab";
($sym, $int) = @A;
if($int != 0) {
if($issil{$sym}) { push @sil, $int; $seensil{$sym}=1; }
else { push @nonsil, $int; }
}
}
foreach $k(keys %issil) {
if(!$seensil{$k}) { die "No such silence phone $k"; }
}
open(F, ">$silphones") || die "opening silphones file $silphones";
open(G, ">$nonsilphones") || die "opening nonsilphones file $nonsilphones";
print F join(":", @sil) . "\n";
print G join(":", @nonsil) . "\n";
close(F);
close(G);
if(@sil == 0) { print STDERR "Warning: silphones.pl no silence phones.\n" }
if(@nonsil == 0) { print STDERR "Warning: silphones.pl no non-silence phones.\n" }


@ -0,0 +1,27 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
while(<>){
@A = split(" ", $_);
@A > 1 || die "Invalid line in spk2utt file: $_";
$s = shift @A;
foreach $u ( @A ) {
print "$u $s\n";
}
}

181
egs/rm/s2/scripts/split_scp.pl Executable file

@ -0,0 +1,181 @@
#!/usr/bin/perl -w
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text file and
# will split it up with an approximately equal number of lines in
# each output file.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can. If you use
# the utt2spk option it will make sure these chunks coincide with
# speaker boundaries. In this case, if there are more chunks
# than speakers (and in some other circumstances), some of the
# resulting chunks will be empty and it
# will print a warning.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
if(@ARGV < 2 ) {
die "Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ... ";
}
if($ARGV[0] =~ m:^-:) {
# Everything inside this block
# corresponds to what we do when the --utt2spk option is used.
$opt = shift @ARGV;
@A = split("=", $opt);
if(@A != 2 || $A[0] ne "--utt2spk") {
die "split_scp.pl: invalid option $ARGV[0]";
}
$utt2spk_file = $A[1];
open(U, "<$utt2spk_file") || die "Failed to open utt2spk file $utt2spk_file";
while(<U>) {
@A = split;
@A == 2 || die "Bad line $_ in utt2spk file $utt2spk_file";
($u,$s) = @A;
$utt2spk{$u} = $s;
}
$inscp = shift @ARGV;
open(I, "<$inscp") || die "Opening input scp file $inscp";
@spkrs = ();
while(<I>) {
@A = split;
if(@A == 0) { die "Empty or space-only line in scp file $inscp"; }
$u = $A[0];
$s = $utt2spk{$u};
if(!defined $s) { die "No such utterance $u in utt2spk file $utt2spk_file"; }
if(!defined $spk_count{$s}) {
push @spkrs, $s;
$spk_count{$s} = 0;
$spk_data{$s} = "";
}
$spk_count{$s}++;
$spk_data{$s} = $spk_data{$s} . $_;
}
# Now split as equally as possible ..
# First allocate spks to files by given approximately
# equal #spks.
$numspks = @spkrs; # number of speakers.
$numscps = @ARGV; # number of output files.
$spksperscp = int( ($numspks+($numscps-1)) / $numscps); # the +($numscps-1) forces rounding up.
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scparray[$scpidx] = []; # [] is array reference.
for($n = $spksperscp * $scpidx;
$n < $numspks && $n < $spksperscp*($scpidx+1);
$n++) {
$spk = $spkrs[$n];
push @{$scparray[$scpidx]}, $spk;
$scpcount[$scpidx] += $spk_count{$spk};
}
}
# Now will try to reassign beginning + ending speakers
# to different scp's and see if it gets more balanced.
# Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
# We can show that if considering changing just 2 scp's, we minimize
# this by minimizing the squared difference in sizes. This is
# equivalent to minimizing the absolute difference in sizes. This
# shows this method is bound to converge.
$changed = 1;
while($changed) {
$changed = 0;
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
# First try to reassign ending spk of this scp.
if($scpidx < $numscps-1) {
$sz = @{$scparray[$scpidx]};
if($sz > 0) {
$spk = $scparray[$scpidx]->[$sz-1];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx];
$nutt2 = $scpcount[$scpidx+1];
if( abs( ($nutt2+$count) - ($nutt1-$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx+1] += $count;
$scpcount[$scpidx] -= $count;
pop @{$scparray[$scpidx]};
unshift @{$scparray[$scpidx+1]}, $spk;
$changed = 1;
}
}
}
if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
$spk = $scparray[$scpidx]->[0];
$count = $spk_count{$spk};
$nutt1 = $scpcount[$scpidx-1];
$nutt2 = $scpcount[$scpidx];
if( abs( ($nutt2-$count) - ($nutt1+$count))
< abs($nutt2 - $nutt1)) { # Would decrease
# size-diff by reassigning spk...
$scpcount[$scpidx-1] += $count;
$scpcount[$scpidx] -= $count;
shift @{$scparray[$scpidx]};
push @{$scparray[$scpidx-1]}, $spk;
$changed = 1;
}
}
}
}
# Now print out the files...
for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
$scpfn = $ARGV[$scpidx];
open(F, ">$scpfn") || die "Could not open scp file $scpfn for writing.";
$count = 0;
if(@{$scparray[$scpidx]} == 0) {
print STDERR "Warning: split_scp.pl producing empty .scp file $scpfn (too many splits and too few speakers?)";
}
foreach $spk ( @{$scparray[$scpidx]} ) {
print F $spk_data{$spk};
$count += $spk_count{$spk};
}
if($count != $scpcount[$scpidx]) { die "Count mismatch [code error]"; }
close(F);
}
} else {
# This block is the "normal" case where there is no --utt2spk
# option and we just break into equal size chunks.
$inscp = shift @ARGV;
open(I, "<$inscp") || die "Opening input scp file $inscp";
$numscps = @ARGV; # size of array.
@F = ();
while(<I>) {
push @F, $_;
}
$numlines = @F;
if($numlines == 0) {
print STDERR "split_scp.pl: warning: empty input scp file $inscp";
}
$linesperscp = int( ($numlines+($numscps-1)) / $numscps); # the +($numscps-1) forces rounding up.
# [just doing int() rounds down].
for($scpidx = 0; $scpidx < @ARGV; $scpidx++) {
$scpfile = $ARGV[$scpidx];
open(O, ">$scpfile") || die "Opening output scp file $scpfile";
for($n = $linesperscp * $scpidx; $n < $numlines && $n < $linesperscp*($scpidx+1); $n++) {
print O $F[$n];
}
close(O) || die "Closing scp file $scpfile";
}
}
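
A small end-to-end check of the speaker-aware mode (utterance, speaker and file names invented): with three utterances from two speakers and two output files, the split falls on the speaker boundary rather than after line 2.

  printf 'spk1_utt1 spk1\nspk1_utt2 spk1\nspk2_utt1 spk2\n' > utt2spk
  printf 'spk1_utt1 a.wav\nspk1_utt2 b.wav\nspk2_utt1 c.wav\n' > wav.scp
  ./split_scp.pl --utt2spk=utt2spk wav.scp scp.1 scp.2
  cat scp.1
  #   spk1_utt1 a.wav
  #   spk1_utt2 b.wav
  cat scp.2
  #   spk2_utt1 c.wav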

59
egs/rm/s2/scripts/subset_scp.pl Executable file

@ -0,0 +1,59 @@
#!/usr/bin/perl -w
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program selects a subset of N elements in the scp.
# It selects them evenly from throughout the scp, in order to
# avoid selecting too many from the same speaker.
# It prints them on the standard output.
if(@ARGV < 2 ) {
die "Usage: subset_scp.pl N in.scp ";
}
$N = shift @ARGV;
if($N == 0) {
die "First command-line parameter to subset_scp.pl must be an integer, got \"$N\"";
}
$inscp = shift @ARGV;
open(I, "<$inscp") || die "Opening input scp file $inscp";
@F = ();
while(<I>) {
push @F, $_;
}
$numlines = @F;
if($N > $numlines) {
die "You requested from subset_scp.pl more elements than available: $N > $numlines";
}
sub select_n {
my ($start,$end,$num_needed) = @_;
my $diff = $end - $start;
if($num_needed > $diff) { die "select_n: code error"; }
if($diff == 1 ) {
if($num_needed > 0) {
print $F[$start];
}
} else {
my $halfdiff = int($diff/2);
my $halfneeded = int($num_needed/2);
select_n($start, $start+$halfdiff, $halfneeded);
select_n($start+$halfdiff, $end, $num_needed - $halfneeded);
}
}
select_n(0, $numlines, $N);
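
To see the even spacing in action (toy input): asking for 2 of 4 lines recursively halves the range and takes one element from each half, here lines 2 and 4.

  printf 'u1 a\nu2 b\nu3 c\nu4 d\n' > in.scp
  ./subset_scp.pl 2 in.scp
  #   u2 b
  #   u4 d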

59
egs/rm/s2/scripts/sym2int.pl Executable file

@ -0,0 +1,59 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
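# Maps symbols (e.g. words or phones) to integers using a symbol table of "symbol id" lines.
# With --ignore-first-field the first token of each line (typically the utterance key) is
# passed through unchanged; with --ignore-oov unknown symbols are printed as-is instead of
# causing an error. Typical usage in this setup (see steps/decode_mono.sh; the file names
# here are just illustrative):
#   scripts/sym2int.pl --ignore-first-field data/words.txt trans.txt > trans.int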
$ignore_oov = 0;
$ignore_first_field = 0;
for($x = 0; $x < 2; $x++) {
if($ARGV[0] eq "--ignore-oov") { $ignore_oov = 1; shift @ARGV; }
if($ARGV[0] eq "--ignore-first-field") { $ignore_first_field = 1; shift @ARGV; }
}
$symtab = shift @ARGV;
if(!defined $symtab) {
die "Usage: sym2int.pl symtab [input transcriptions] > output transcriptions\n";
}
open(F, "<$symtab") || die "Error opening symbol table file $symtab";
while(<F>) {
@A = split(" ", $_);
@A == 2 || die "bad line in symbol table file: $_";
$sym2int{$A[0]} = $A[1] + 0;
}
while(<>) {
@A = split(" ", $_);
if(@A == 0) {
die "Empty line in transcriptions input.";
}
if($ignore_first_field) {
$key = shift @A;
print $key . " ";
}
foreach $a (@A) {
$i = $sym2int{$a};
if(!defined ($i)) {
if($ignore_oov) {
print $a . " " ;
} else {
die "sym2int.pl: undefined symbol $a\n";
}
}
print $i . " ";
}
print "\n";
}

Просмотреть файл

@ -0,0 +1,33 @@
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
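# Converts an utt2spk map (one "utterance speaker" pair per line) into the spk2utt format
# ("speaker utt1 utt2 ..."), keeping speakers in order of first appearance.
# e.g. (hypothetical ids) the lines "u1 s1", "u2 s1", "u3 s2" become "s1 u1 u2" and "s2 u3".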
while(<>){
@A = split(" ", $_);
@A == 2 || die "Invalid line in utt2spk file: $_";
($u,$s) = @A;
if(!$seen_spk{$s}) {
$seen_spk{$s} = 1;
push @spklist, $s;
}
$uttlist{$s} = $uttlist{$s} . "$u ";
}
foreach $s (@spklist) {
$l = $uttlist{$s};
$l =~ s: $::; # remove trailing space.
print "$s $l\n";
}

45
egs/rm/s2/steps/decode_mono.sh Executable file
Просмотреть файл

@ -0,0 +1,45 @@
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Monophone decoding script.
if [ -f path.sh ]; then . path.sh; fi
dir=exp/decode_mono
tree=exp/mono/tree
mkdir -p $dir
model=exp/mono/final.mdl
graphdir=exp/graph_mono
scripts/mkgraph.sh --mono $tree $model $graphdir
for test in mar87 oct87 feb89 oct89 feb91 sep92; do
(
feats="ark:add-deltas --print-args=false scp:data/test_${test}.scp ark:- |"
gmm-decode-faster --beam=20.0 --acoustic-scale=0.08333 --word-symbol-table=data/words.txt $model $graphdir/HCLG.fst "$feats" ark,t:$dir/test_${test}.tra ark,t:$dir/test_${test}.ali 2> $dir/decode_${test}.log
# the ,p option lets it score partial output without dying..
scripts/sym2int.pl --ignore-first-field data/words.txt data_prep/test_${test}_trans.txt | \
compute-wer --mode=present ark:- ark,p:$dir/test_${test}.tra >& $dir/wer_${test}
) &
done
wait
grep WER $dir/wer_* | \
awk '{n=n+$4; d=d+$6} END{ printf("Average WER is %f (%d / %d) \n", (100.0*n)/d, n, d); }' \
> $dir/wer

Просмотреть файл

@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# To be run from .. (one directory up from here)
if [ $# != 1 ]; then
echo "usage: make_mfcc_test.sh <abs-path-to-tmpdir>"
exit 1;
fi
if [ -f path.sh ]; then . path.sh; fi
dir=exp/make_mfcc
mkdir -p $dir
root_out=$1
mkdir -p $root_out
for test in mar87 oct87 feb89 oct89 feb91 sep92; do
scpin=data_prep/test_${test}_wav.scp
# Making it like this so it works for others on the BUT filesystem.
# It will generate the correct scp file without running the feature extraction.
log=$dir/make_mfcc_test_${test}.log
(
compute-mfcc-feats --verbose=2 --config=conf/mfcc.conf scp:$scpin ark,scp:$root_out/test_${test}_raw_mfcc.ark,$root_out/test_${test}_raw_mfcc.scp 2> $log || tail $log
cp $root_out/test_${test}_raw_mfcc.scp data/test_${test}.scp
) &
done
wait
echo "If the above produced no output on the screen, it succeeded."

Просмотреть файл

@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# To be run from .. (one directory up from here)
if [ $# != 1 ]; then
echo "usage: make_mfcc_train.sh <abs-path-to-tmpdir>";
exit 1;
fi
if [ -f path.sh ]; then . path.sh; fi
scpin=data_prep/train_wav.scp
dir=exp/make_mfcc
mkdir -p $dir
root_out=$1
mkdir -p $root_out
scripts/split_scp.pl $scpin $dir/train_wav{1,2,3,4}.scp
for n in 1 2 3 4; do # Use 4 CPUs
log=$dir/make_mfcc_train.$n.log
compute-mfcc-feats --verbose=2 --config=conf/mfcc.conf scp:$dir/train_wav${n}.scp ark,scp:$root_out/train_raw_mfcc${n}.ark,$root_out/train_raw_mfcc${n}.scp 2> $log || tail $log &
done
wait;
cat $root_out/train_raw_mfcc{1,2,3,4}.scp > data/train.scp
echo "If the above produced no output on the screen, it succeeded."

Просмотреть файл

@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# The output of this script is the symbol tables data/{words.txt,phones.txt},
# and the grammars and lexicons data/{L,G}{,_disambig}.fst
# To be run from ..
if [ -f path.sh ]; then . path.sh; fi
cp data_prep/G.txt data/
scripts/make_words_symtab.pl < data/G.txt > data/words.txt
cp data_prep/lexicon.txt data/
scripts/make_phones_symtab.pl < data/lexicon.txt > data/phones.txt
silphones="sil"; # This would in general be a space-separated list of all silence phones. E.g. "sil vn"
# Generate colon-separated lists of silence and non-silence phones.
scripts/silphones.pl data/phones.txt "$silphones" data/silphones.csl data/nonsilphones.csl
ndisambig=`scripts/add_lex_disambig.pl data/lexicon.txt data/lexicon_disambig.txt`
scripts/add_disambig.pl data/phones.txt $ndisambig > data/phones_disambig.txt
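# Roughly speaking, add_lex_disambig.pl appends #1, #2, ... to lexicon entries that share a
# pronunciation or are prefixes of other pronunciations and prints how many such symbols were
# needed; add_disambig.pl then adds those symbols to the phone symbol table so that the
# disambiguated lexicon FST below can be compiled.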
# Create train transcripts in integer format:
cat data_prep/train_trans.txt | \
scripts/sym2int.pl --ignore-first-field data/words.txt > data/train.tra
# Get lexicon in FST format.
# silprob = 0.5: same prob as word.
scripts/make_lexicon_fst.pl data/lexicon.txt 0.5 sil | fstcompile --isymbols=data/phones.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel > data/L.fst
scripts/make_lexicon_fst.pl data/lexicon_disambig.txt 0.5 sil | fstcompile --isymbols=data/phones_disambig.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel > data/L_disambig.fst
fstcompile --isymbols=data/words.txt --osymbols=data/words.txt --keep_isymbols=false --keep_osymbols=false data/G.txt > data/G.fst
# Checking that G is stochastic [note, it wouldn't be for an Arpa]
fstisstochastic data/G.fst || echo Error
# Checking that disambiguated lexicon times G is determinizable
fsttablecompose data/L_disambig.fst data/G.fst | fstdeterminize >/dev/null || echo Error
# Checking that LG is stochastic:
fsttablecompose data/L.fst data/G.fst | fstisstochastic || echo Error
## Check lexicon.
## just have a look and make sure it seems sane.
fstprint --isymbols=data/phones.txt --osymbols=data/words.txt data/L.fst | head

92
egs/rm/s2/steps/train_mono.sh Executable file
Просмотреть файл

@ -0,0 +1,92 @@
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation Arnab Ghoshal
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# To be run from ..
if [ -f path.sh ]; then . path.sh; fi
# Train the monophone on a subset-- no point using all the data.
dir=exp/mono
n=1000
feats="ark:add-deltas --print-args=false scp:$dir/train.scp ark:- |"
# need to quote when passing as an argument, as in "$feats",
# since it has spaces in it.
scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1"
numiters=30 # Number of iterations of training
maxiterinc=20 # Last iter to increase #Gauss on.
numgauss=250 # Initial num-Gauss (must be more than #states=3*phones).
totgauss=1000 # Target #Gaussians.
incgauss=$[($totgauss-$numgauss)/$maxiterinc] # per-iter increment for #Gauss
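# e.g. with the defaults above: (1000-250)/20 = 37 extra Gaussians per iteration (integer
# division), added after each of the first $maxiterinc iterations.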
realign_iters="1 2 3 4 5 6 7 8 9 10 12 15 20 25";
mkdir -p $dir
scripts/subset_scp.pl $n data/train.scp > $dir/train.scp
silphones=`cat data/silphones.csl | sed 's/:/ /g'`
nonsilphones=`cat data/nonsilphones.csl | sed 's/:/ /g'`
cat conf/topo.proto | sed "s:NONSILENCEPHONES:$nonsilphones:" | sed "s:SILENCEPHONES:$silphones:" > $dir/topo
gmm-init-mono '--train-feats=ark:head -10 data/train.scp | add-deltas scp:- ark:- |' $dir/topo 39 $dir/0.mdl $dir/tree 2> $dir/init.out || exit 1;
echo "Compiling training graphs"
compile-train-graphs $dir/tree $dir/0.mdl data/L.fst \
"ark:scripts/subset_scp.pl $n data/train.tra|" \
"ark:|gzip -c >$dir/graphs.fsts.gz" 2>$dir/compile_graphs.log || exit 1
echo Pass 0
align-equal-compiled "ark:gunzip -c $dir/graphs.fsts.gz|" "$feats" \
ark,t,f:- 2>$dir/align.0.log | \
gmm-acc-stats-ali --binary=true $dir/0.mdl "$feats" ark:- \
$dir/0.acc 2> $dir/acc.0.log || exit 1;
# In the following steps, the --min-gaussian-occupancy=3 option is important, otherwise
# we fail to estimate "rare" phones, and later on they never align properly.
gmm-est --min-gaussian-occupancy=3 --mix-up=$numgauss \
$dir/0.mdl $dir/0.acc $dir/1.mdl 2> $dir/update.0.log || exit 1;
rm $dir/0.acc
beam=4 # will change to 8 below after 1st pass
x=1
while [ $x -lt $numiters ]; do
echo "Pass $x"
if echo $realign_iters | grep -w $x >/dev/null; then
echo "Aligning data"
gmm-align-compiled $scale_opts --beam=$beam --retry-beam=$[$beam*4] $dir/$x.mdl \
"ark:gunzip -c $dir/graphs.fsts.gz|" "$feats" t,ark:$dir/cur.ali \
2> $dir/align.$x.log || exit 1;
fi
gmm-acc-stats-ali --binary=false $dir/$x.mdl "$feats" ark:$dir/cur.ali $dir/$x.acc 2> $dir/acc.$x.log || exit 1;
gmm-est --mix-up=$numgauss $dir/$x.mdl $dir/$x.acc $dir/$[$x+1].mdl 2> $dir/update.$x.log || exit 1;
rm $dir/$x.mdl $dir/$x.acc
if [ $x -le $maxiterinc ]; then
numgauss=$[$numgauss+$incgauss];
fi
beam=8
x=$[$x+1]
done
( cd $dir; rm final.mdl 2>/dev/null; ln -s $x.mdl final.mdl )
# example of showing the alignments:
# show-alignments data/phones.txt $dir/30.mdl ark:$dir/cur.ali | head -4

92
egs/rm/s2/steps/train_nnet.sh Executable file
Просмотреть файл

@ -0,0 +1,92 @@
#!/bin/bash
# To be run from ..
if [ -f path.sh ]; then . path.sh; fi
dir=exp/nnet
mkdir -p $dir/{log,nnet}
#use following features and alignments
cp exp/mono/train.scp exp/mono/cur.ali $dir
head -n 800 $dir/train.scp > $dir/train.scp.tr
tail -n 200 $dir/train.scp > $dir/train.scp.cv
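# i.e. an (approximately) 80/20 split of the 1000-utterance monophone subset into a training
# part and a held-out cross-validation part.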
feats="ark:add-deltas --print-args=false scp:$dir/train.scp ark:- |"
feats_tr="ark:add-deltas --print-args=false scp:$dir/train.scp.tr ark:- |"
feats_cv="ark:add-deltas --print-args=false scp:$dir/train.scp.cv ark:- |"
labels="ark:$dir/cur.ali"
#compute per utterance CMVN
cmvn="ark:$dir/cmvn.ark"
compute-cmvn-stats "$feats" $cmvn
feats_tr="$feats_tr apply-cmvn --print-args=false --norm-vars=true $cmvn ark:- ark:- |"
feats_cv="$feats_cv apply-cmvn --print-args=false --norm-vars=true $cmvn ark:- ark:- |"
#initialize the nnet
mlp_init=$dir/nnet.init
scripts/gen_mlp_init.py --dim=39:512:300 --gauss --negbias > $mlp_init
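# --dim gives the layer sizes input:hidden:output: 39 matches the MFCC+delta+acceleration
# features produced by add-deltas, 512 is the hidden-layer size, and the output size must
# cover the label range found in cur.ali (assumed to be 300 here); --gauss and --negbias
# presumably request Gaussian-distributed initial weights and negative initial biases.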
#global config for the training
max_iters=20
start_halving_inc=0.5
end_halving_inc=0.1
lrate=0.001
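# Schedule used by the loop below: once the cross-validation accuracy improves by less than
# start_halving_inc (0.5%) per iteration, the learning rate is halved every iteration, and
# training stops when the improvement drops below end_halving_inc (0.1%) or max_iters is hit.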
nnet-train-xent-hardlab-perutt --cross-validate=true $mlp_init "$feats_cv" "$labels" &> $dir/log/prerun.log
if [ $? != 0 ]; then cat $dir/log/prerun.log; exit 1; fi
acc=$(cat $dir/log/prerun.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1)
echo CROSSVAL PRERUN ACCURACY $acc
mlp_best=$mlp_init
mlp_base=${mlp_init##*/}; mlp_base=${mlp_base%.*}
halving=0
for iter in $(seq -w $max_iters); do
mlp_next=$dir/nnet/${mlp_base}_iter${iter}
nnet-train-xent-hardlab-perutt --learn-rate=$lrate $mlp_best "$feats_tr" "$labels" $mlp_next &> $dir/log/iter$iter.log
if [ $? != 0 ]; then cat $dir/log/iter$iter.log; exit 1; fi
tr_acc=$(cat $dir/log/iter$iter.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1)
echo TRAIN ITERATION $iter ACCURACY $tr_acc LRATE $lrate
nnet-train-xent-hardlab-perutt --cross-validate=true $mlp_next "$feats_cv" "$labels" 1>>$dir/log/iter$iter.log 2>>$dir/log/iter$iter.log
if [ $? != 0 ]; then cat $dir/log/iter$iter.log; exit 1; fi
#accept or reject new parameters
acc_new=$(cat $dir/log/iter$iter.log | grep Xent | tail -n 1 | cut -d'[' -f 2 | cut -d'%' -f 1)
echo CROSSVAL ITERATION $iter ACCURACY $acc_new
acc_prev=$acc
if [ 1 == $(awk 'BEGIN{print('$acc_new' > '$acc')}') ]; then
acc=$acc_new
mlp_best=$dir/nnet/$mlp_base.iter${iter}_tr$(printf "%.5g" $tr_acc)_cv$(printf "%.5g" $acc_new)
mv $mlp_next $mlp_best
echo nnet $mlp_best accepted
else
mlp_reject=$dir/nnet/$mlp_base.iter${iter}_tr$(printf "%.5g" $tr_acc)_cv$(printf "%.5g" $acc_new)
mv $mlp_next $mlp_reject
echo nnet $mlp_reject rejected
fi
#stopping criterion
if [[ 1 == $halving && 1 == $(awk 'BEGIN{print('$acc' < '$acc_prev'+'$end_halving_inc')}') ]]; then
echo finished, too small improvement $(awk 'BEGIN{print('$acc'-'$acc_prev')}')
break
fi
#start annealing when improvement is low
if [ 1 == $(awk 'BEGIN{print('$acc' < '$acc_prev'+'$start_halving_inc')}') ]; then
halving=1
fi
#do annealing
if [ 1 == $halving ]; then
lrate=$(awk 'BEGIN{print('$lrate'*0.5)}')
fi
done
if [ $mlp_best != $mlp_init ]; then
iter=$(echo $mlp_best | sed 's/^.*iter\([0-9][0-9]*\).*$/\1/')
fi
mlp_final=$dir/${mlp_base}_final_iter${iter:-0}_acc${acc}
cp $mlp_best $mlp_final
echo final network $mlp_final

Просмотреть файл

@ -7,7 +7,8 @@
SUBDIRS = base matrix util feat tree optimization gmm transform sgmm \
fstext hmm lm decoder \
bin fstbin gmmbin fgmmbin sgmmbin featbin
bin fstbin gmmbin fgmmbin sgmmbin featbin \
nnet nnetbin
all: $(SUBDIRS)
echo Done

47
src/nnet/Makefile Normal file
Просмотреть файл

@ -0,0 +1,47 @@
all:
include ../kaldi.mk
TESTFILES = nnet-test
OBJFILES = nnet-nnet.o nnet-component.o nnet-loss.o
LIBFILE = kaldi-nnet.a
all: $(LIBFILE) $(TESTFILES)
$(LIBFILE): $(OBJFILES)
$(AR) -cru $(LIBFILE) $(OBJFILES)
$(RANLIB) $(LIBFILE)
$(TESTFILES): $(LIBFILE) ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../util/kaldi-util.a
# Rule below would expand to, e.g.:
# ../base/kaldi-base.a:
# make -C ../base kaldi-base.a
# -C option to make is same as changing directory.
%.a:
$(MAKE) -C ${@D} ${@F}
clean:
rm *.o *.a $(TESTFILES)
test: $(TESTFILES)
for x in $(TESTFILES); do ./$$x >&/dev/null || (echo "***test $$x failed***"; exit 1); done
echo Tests succeeded
.valgrind: $(TESTFILES)
depend:
-$(CXX) -M $(CXXFLAGS) *.cc > .depend.mk
# removing automatic making of "depend" as it's quite slow.
#.depend.mk: depend
-include .depend.mk

Просмотреть файл

@ -0,0 +1,94 @@
// nnet/nnet-activation.h
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET_ACTIVATION_H
#define KALDI_NNET_ACTIVATION_H
#include "nnet/nnet-component.h"
namespace kaldi {
class Sigmoid : public Component {
public:
Sigmoid(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet)
: Component(dim_in, dim_out, nnet)
{ }
~Sigmoid()
{ }
ComponentType GetType() const {
return kSigmoid;
}
void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
//y = 1/(1+e^-x)
for(MatrixIndexT r=0; r<out->NumRows(); r++) {
for(MatrixIndexT c=0; c<out->NumCols(); c++) {
(*out)(r,c) = 1.0/(1.0+exp(-in(r,c)));
}
}
}
void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
//ey = y(1-y)ex
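//the sigmoid output y is recovered from the network's forward-pass buffer that follows
//this layer (index of this layer + 1), so it does not have to be recomputed here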
const Matrix<BaseFloat>& y = nnet_->PropagateBuffer()[nnet_->IndexOfLayer(*this)+1];
for(MatrixIndexT r=0; r<out_err->NumRows(); r++) {
for(MatrixIndexT c=0; c<out_err->NumCols(); c++) {
(*out_err)(r,c) = y(r,c)*(1.0-y(r,c))*in_err(r,c);
}
}
}
};
class Softmax : public Component {
public:
Softmax(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet)
: Component(dim_in, dim_out, nnet)
{ }
~Softmax()
{ }
ComponentType GetType() const {
return kSoftmax;
}
void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
//y = e^x_j/sum_j(e^x_j)
out->CopyFromMat(in);
for(MatrixIndexT r=0; r<out->NumRows(); r++) {
out->Row(r).ApplySoftMax();
}
}
void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
//simply copy the error
//(ie. assume crossentropy error function,
// while in_err contains (net_output-target) :
// this is already derivative of the error with
// respect to activations of last layer neurons)
out_err->CopyFromMat(in_err);
}
};
} // namespace
#endif

Просмотреть файл

@ -0,0 +1,127 @@
// nnet/nnet-biasedlinearity.h
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET_BIASEDLINEARITY_H
#define KALDI_NNET_BIASEDLINEARITY_H
#include "nnet/nnet-component.h"
namespace kaldi {
class BiasedLinearity : public UpdatableComponent {
public:
BiasedLinearity(MatrixIndexT dim_in, MatrixIndexT dim_out, Nnet* nnet)
: UpdatableComponent(dim_in, dim_out, nnet),
linearity_(dim_out,dim_in), bias_(dim_out),
linearity_corr_(dim_out,dim_in), bias_corr_(dim_out)
{ }
~BiasedLinearity()
{ }
ComponentType GetType() const {
return kBiasedLinearity;
}
void ReadData(std::istream& is, bool binary) {
linearity_.Read(is,binary);
bias_.Read(is,binary);
KALDI_ASSERT(linearity_.NumRows() == output_dim_);
KALDI_ASSERT(linearity_.NumCols() == input_dim_);
KALDI_ASSERT(bias_.Dim() == output_dim_);
}
void WriteData(std::ostream& os, bool binary) const {
linearity_.Write(os,binary);
bias_.Write(os,binary);
}
void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
//precopy bias
for (MatrixIndexT i=0; i<out->NumRows(); i++) {
out->CopyRowFromVec(bias_,i);
}
//multiply by weights^t
out->AddMatMat(1.0,in,kNoTrans,linearity_,kTrans,1.0);
}
void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
//multiply error by weights
out_err->AddMatMat(1.0,in_err,kNoTrans,linearity_,kNoTrans,0.0);
}
void Update(const Matrix<BaseFloat>& input, const Matrix<BaseFloat>& err) {
//compute gradient
linearity_corr_.AddMatMat(1.0,err,kTrans,input,kNoTrans,momentum_);
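//AddMatMat(alpha,A,transA,B,transB,beta) computes *this = alpha*op(A)*op(B) + beta*(*this),
//so linearity_corr_ = err^T * input + momentum_ * linearity_corr_ (classical momentum smoothing);
//the bias correction below is smoothed the same way via Scale() followed by AddRowSumMat()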
bias_corr_.Scale(momentum_);
bias_corr_.AddRowSumMat(err);
//l2 regularization
if(l2_penalty_ != 0.0) {
linearity_.AddMat(-learn_rate_*l2_penalty_*input.NumRows(),linearity_);
}
//l1 regularization
if(l1_penalty_ != 0.0) {
BaseFloat l1 = learn_rate_*input.NumRows()*l1_penalty_;
for(MatrixIndexT r=0; r<linearity_.NumRows(); r++) {
for(MatrixIndexT c=0; c<linearity_.NumCols(); c++) {
if(linearity_(r,c)==0.0) continue; //skip L1 if zero weight!
BaseFloat l1sign = l1;
if(linearity_(r,c) < 0.0)
l1sign = -l1;
BaseFloat before = linearity_(r,c);
BaseFloat after = linearity_(r,c)-learn_rate_*linearity_corr_(r,c)-l1sign;
if((after > 0.0) ^ (before > 0.0)) {
linearity_(r,c) = 0.0;
linearity_corr_(r,c) = 0.0;
} else {
linearity_(r,c) -= l1sign;
}
}
}
}
//update
linearity_.AddMat(-learn_rate_,linearity_corr_);
bias_.AddVec(-learn_rate_,bias_corr_);
/*
std::cout <<"I"<< input.Row(0);
std::cout <<"E"<< err.Row(0);
std::cout <<"CORL"<< linearity_corr_.Row(0);
std::cout <<"CORB"<< bias_corr_;
std::cout <<"L"<< linearity_.Row(0);
std::cout <<"B"<< bias_;
std::cout << "\n";
*/
//std::cout << l1_penalty_ << l2_penalty_ << momentum_ << learn_rate_ << "\n";
}
private:
Matrix<BaseFloat> linearity_;
Vector<BaseFloat> bias_;
Matrix<BaseFloat> linearity_corr_;
Vector<BaseFloat> bias_corr_;
};
} //namespace
#endif

Просмотреть файл

@ -0,0 +1,98 @@
// nnet/nnet-component.cc
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet-component.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-activation.h"
#include "nnet/nnet-biasedlinearity.h"
namespace kaldi {
const struct Component::key_value Component::kMarkerMap[] = {
{ Component::kBiasedLinearity,"<biasedlinearity>" },
{ Component::kSigmoid,"<sigmoid>" },
{ Component::kSoftmax,"<softmax>" }
};
const char* Component::TypeToMarker(ComponentType t) {
int32 N=sizeof(kMarkerMap)/sizeof(kMarkerMap[0]);
for(int i=0; i<N; i++) {
if(kMarkerMap[i].key == t)
return kMarkerMap[i].value;
}
KALDI_ERR << "Unknown type" << t;
return NULL;
}
Component::ComponentType Component::MarkerToType(const std::string& s) {
int32 N=sizeof(kMarkerMap)/sizeof(kMarkerMap[0]);
for(int i=0; i<N; i++) {
if(0 == strcmp(kMarkerMap[i].value,s.c_str()))
return kMarkerMap[i].key;
}
KALDI_ERR << "Unknown marker" << s;
return kUnknown;
}
Component* Component::Read(std::istream& is, bool binary, Nnet* nnet) {
int32 dim_out, dim_in;
std::string token;
int first_char = PeekMarker(is,binary);
if(first_char == EOF) return NULL;
ReadMarker(is,binary,&token);
Component::ComponentType comp_type = Component::MarkerToType(token);
ReadBasicType(is,binary,&dim_out);
ReadBasicType(is,binary,&dim_in);
Component* p_comp;
switch(comp_type) {
case Component::kBiasedLinearity :
p_comp = new BiasedLinearity(dim_in,dim_out,nnet);
break;
case Component::kSigmoid :
p_comp = new Sigmoid(dim_in,dim_out,nnet);
break;
case Component::kSoftmax :
p_comp = new Softmax(dim_in,dim_out,nnet);
break;
case Component::kUnknown :
default :
KALDI_ERR << "Missing type: " << token;
}
p_comp->ReadData(is,binary);
return p_comp;
}
void Component::Write(std::ostream& os, bool binary) const {
WriteMarker(os,binary,Component::TypeToMarker(GetType()));
WriteBasicType(os,binary,OutputDim());
WriteBasicType(os,binary,InputDim());
if(!binary) os << "\n";
this->WriteData(os,binary);
}
} // namespace

273
src/nnet/nnet-component.h Normal file
Просмотреть файл

@ -0,0 +1,273 @@
// nnet/nnet-component.h
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET_COMPONENT_H
#define KALDI_NNET_COMPONENT_H
#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
//#include "nnet/nnet-nnet.h"
#include <iostream>
namespace kaldi {
//declare the nnet class so we can declare pointer
class Nnet;
/**
* Abstract class, basic element of the network,
* it is a box with defined inputs, outputs,
* and a transformation-function interface.
*
* It is able to propagate and backpropagate;
* the exact implementation is left to the descendants.
*
* The data buffers are not included
* and will be managed from outside.
*/
class Component
{
//////////////////////////////////////////////////////////////
// Disable copy construction and assignment
private:
Component(Component&);
Component& operator=(Component&);
//////////////////////////////////////////////////////////////
// Polymorphic Component RTTI
public:
/// Types of the net components
typedef enum {
kUnknown = 0x0,
kUpdatableComponent = 0x0100,
kBiasedLinearity,
kSharedLinearity,
kActivationFunction = 0x0200,
kSoftmax,
kSigmoid,
kTranform = 0x0400,
kExpand,
kCopy,
kTranspose,
kBlockLinearity,
kBias,
kWindow,
kLog
} ComponentType;
/// Pair of type and marker
struct key_value {
const Component::ComponentType key;
const char* value;
};
/// Mapping of types and markers
static const struct key_value kMarkerMap[];
/// Convert component type to marker
static const char* TypeToMarker(ComponentType t);
/// Convert marker to component type
static ComponentType MarkerToType(const std::string& s);
//////////////////////////////////////////////////////////////
// Constructor & Destructor
public:
Component(MatrixIndexT input_dim, MatrixIndexT output_dim, Nnet* nnet)
: input_dim_(input_dim), output_dim_(output_dim), nnet_(nnet)
{ }
virtual ~Component()
{ }
//////////////////////////////////////////////////////////////
// Public interface
public:
/// Get Type Identification of the component
virtual ComponentType GetType() const = 0;
/// Check if contains trainable parameters
virtual bool IsUpdatable() const {
return false;
}
/// Get size of input vectors
MatrixIndexT InputDim() const {
return input_dim_;
}
/// Get size of output vectors
MatrixIndexT OutputDim() const {
return output_dim_;
}
/// Perform forward pass propagation Input->Output
void Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out);
/// Perform backward pass propagation ErrorInput->ErrorOutput
void Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err);
/// Read component from stream
static Component* Read(std::istream& is, bool binary, Nnet* nnet);
/// Write component to stream
void Write(std::ostream& os, bool binary) const;
///////////////////////////////////////////////////////////////
// abstract interface for propagation/backpropagation
protected:
/// Forward pass transformation (to be implemented by descendants...)
virtual void PropagateFnc(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) = 0;
/// Backward pass transformation (to be implemented by descendants...)
virtual void BackpropagateFnc(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) = 0;
/// Reads the component content
virtual void ReadData(std::istream& is, bool binary)
{ }
/// Writes the component content
virtual void WriteData(std::ostream& os, bool binary) const
{ }
///////////////////////////////////////////////////////////////
// data members
protected:
MatrixIndexT input_dim_; ///< Size of input vectors
MatrixIndexT output_dim_; ///< Size of output vectors
Nnet* nnet_; ///< Pointer to the whole network
};
/**
* Class UpdatableComponent is a Component which has
* trainable parameters and contains some global
* parameters for stochastic gradient descent
* (learn rate, momentum, L2, L1)
*/
class UpdatableComponent : public Component {
//////////////////////////////////////////////////////////////
// Constructor & Destructor
public:
UpdatableComponent(MatrixIndexT input_dim, MatrixIndexT output_dim, Nnet* nnet)
: Component(input_dim,output_dim,nnet),
learn_rate_(0.0), momentum_(0.0), l2_penalty_(0.0), l1_penalty_(0.0)
{ }
virtual ~UpdatableComponent()
{ }
//////////////////////////////////////////////////////////////
// Public interface
public:
/// Check if contains trainable parameters
bool IsUpdatable() const {
return true;
}
/// Compute gradient and update parameters
virtual void Update(const Matrix<BaseFloat>& input, const Matrix<BaseFloat>& err) = 0;
/// Sets the learning rate of gradient descent
void LearnRate(BaseFloat lrate) {
learn_rate_ = lrate;
}
/// Gets the learning rate of gradient descent
BaseFloat LearnRate() {
return learn_rate_;
}
/// Sets momentum
void Momentum(BaseFloat mmt) {
momentum_ = mmt;
}
/// Gets momentum
BaseFloat Momentum() {
return momentum_;
}
/// Sets L2 penalty (weight decay)
void L2Penalty(BaseFloat l2) {
l2_penalty_ = l2;
}
/// Gets L2 penalty (weight decay)
BaseFloat L2Penalty() {
return l2_penalty_;
}
/// Sets L1 penalty (sparsity promotion)
void L1Penalty(BaseFloat l1) {
l1_penalty_ = l1;
}
/// Gets L1 penalty (sparsity promotion)
BaseFloat L1Penalty() {
return l1_penalty_;
}
protected:
BaseFloat learn_rate_; ///< learning rate (0.0..0.01)
BaseFloat momentum_; ///< momentum value (0.0..1.0)
BaseFloat l2_penalty_; ///< L2 regularization constant (0.0..1e-4)
BaseFloat l1_penalty_; ///< L1 regularization constant (0.0..1e-4)
};
//////////////////////////////////////////////////////////////////////////
// INLINE FUNCTIONS
// Component::
inline void Component::Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
if(input_dim_ != in.NumCols()) {
KALDI_ERR << "Nonmatching dims, component:" << input_dim_ << " data:" << in.NumCols();
}
if(output_dim_ != out->NumCols() || in.NumRows() != out->NumRows()) {
out->Resize(in.NumRows(), output_dim_);
}
PropagateFnc(in, out);
}
inline void Component::Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
if(output_dim_ != in_err.NumCols()) {
KALDI_ERR << "Nonmatching dims, component:" << output_dim_
<< " data:" << in_err.NumCols();
}
if(input_dim_ != out_err->NumCols() || in_err.NumRows() != out_err->NumRows()) {
out_err->Resize(in_err.NumRows(), input_dim_);
}
BackpropagateFnc(in_err, out_err);
}
//////////////////////////////////////////////////////////////////////////
// INLINE FUNCTIONS
// UpdatableComponent::
// nothing for now!
} // namespace kaldi
#endif

151
src/nnet/nnet-loss.cc Normal file
Просмотреть файл

@ -0,0 +1,151 @@
// nnet/nnet-loss.cc
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet-loss.h"
#include <sstream>
namespace kaldi {
void Xent::Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff) {
KALDI_ASSERT(net_out.NumCols() == target.NumCols());
KALDI_ASSERT(net_out.NumRows() == target.NumRows());
diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined);
//compute derivative wrt. activations of last layer of neurons
diff->CopyFromMat(net_out);
diff->AddMat(-1.0,target);
//we'll not produce per-frame classification accuracy for soft labels
correct_ = -1;
//compute xentropy
BaseFloat val;
for(int32 r=0; r<net_out.NumRows(); r++) {
for(int32 c=0; c<net_out.NumCols(); c++) {
val = -target(r,c)*log(net_out(r,c));
if(isinf(val)) val = 1e10;
loss_ += val;
}
}
frames_ += net_out.NumRows();
}
void Xent::Eval(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target, Matrix<BaseFloat>* diff) {
KALDI_ASSERT(net_out.NumRows() == (int32)target.size());
//check the labels
int32 max=0;
std::vector<int32>::const_iterator it;
for(it=target.begin(); it!=target.end(); ++it) {
if(max < *it) max = *it;
}
if(max > net_out.NumCols()) {
KALDI_ERR << "Network has " << net_out.NumCols()
<< " outputs while having " << max << " labels";
}
//compute derivative wrt. activations of last layer of neurons
diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined);
diff->CopyFromMat(net_out);
for(int32 r=0; r<(int32)target.size(); r++) {
KALDI_ASSERT(target.at(r) <= diff->NumCols());
(*diff)(r,target.at(r)-1) -= 1.0;
}
//accumulate per-frame classification accuracy from the hard labels
correct_ += Correct(net_out,target);
//compute xentropy
BaseFloat val;
for(int32 r=0; r<net_out.NumRows(); r++) {
KALDI_ASSERT(target.at(r) <= net_out.NumCols());
val = -log(net_out(r,target.at(r)-1));
if(isinf(val)) val = 1e10;
loss_ += val;
}
frames_ += net_out.NumRows();
}
std::string Xent::Report() {
std::ostringstream oss;
oss << "Xent:" << loss_ << " frames:" << frames_
<< " err/frm:" << loss_/frames_;
if(correct_ >= 0.0) {
oss << " correct[" << 100.0*correct_/frames_ << "%]";
}
oss << std::endl;
return oss.str();
}
int32 Xent::Correct(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target) {
int32 correct = 0;
for(int32 r=0; r<net_out.NumRows(); r++) {
BaseFloat max = -1;
int32 max_id = -1;
for(int32 c=0; c<net_out.NumCols(); c++) {
if(max < net_out(r,c)) {
max = net_out(r,c);
max_id = c;
}
}
if(target.at(r)-1 == max_id) {
correct++;
}
}
return correct;
}
void Mse::Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff) {
KALDI_ASSERT(net_out.NumCols() == target.NumCols());
KALDI_ASSERT(net_out.NumRows() == target.NumRows());
diff->Resize(net_out.NumRows(),net_out.NumCols(),kUndefined);
//compute derivative w.r.t. neural network outputs
diff->CopyFromMat(net_out);
diff->AddMat(-1.0,target);
//compute mean square error
BaseFloat val;
for(int32 r=0; r<net_out.NumRows(); r++) {
for(int32 c=0; c<net_out.NumCols(); c++) {
val = target(r,c) - net_out(r,c);
loss_ += val*val;
}
}
frames_ += net_out.NumRows();
}
std::string Mse::Report() {
std::ostringstream oss;
oss << "Mse:" << loss_ << " frames:" << frames_
<< " err/frm:" << loss_/frames_
<< std::endl;
return oss.str();
}
} // namespace

75
src/nnet/nnet-loss.h Normal file
Просмотреть файл

@ -0,0 +1,75 @@
// nnet/nnet-loss.h
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET_LOSS_H
#define KALDI_NNET_LOSS_H
#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
namespace kaldi {
class Xent {
public:
Xent()
: frames_(0), correct_(0), loss_(0.0)
{ }
~Xent()
{ }
/// Evaluate cross entropy from soft labels (matrix of per-frame targets)
void Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff);
/// Evaluate cross entropy from hard labels (vector of 1-based label indices)
void Eval(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target, Matrix<BaseFloat>* diff);
/// Generate string with error report
std::string Report();
private:
int32 Correct(const Matrix<BaseFloat>& net_out, const std::vector<int32>& target);
private:
int32 frames_;
int32 correct_;
double loss_;
};
class Mse {
public:
Mse()
: frames_(0), loss_(0.0)
{ }
~Mse()
{ }
/// Evaluate mean square error from target values
void Eval(const Matrix<BaseFloat>& net_out, const Matrix<BaseFloat>& target, Matrix<BaseFloat>* diff);
/// Generate string with error report
std::string Report();
private:
int32 frames_;
double loss_;
};
} // namespace
#endif

217
src/nnet/nnet-nnet.cc Normal file
Просмотреть файл

@ -0,0 +1,217 @@
// nnet/nnet-nnet.cc
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-component.h"
#include "nnet/nnet-activation.h"
#include "nnet/nnet-biasedlinearity.h"
namespace kaldi {
void Nnet::Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
KALDI_ASSERT(NULL != out);
if(LayerCount() == 0) {
out->Resize(in.NumRows(),in.NumCols(),kUndefined);
out->CopyFromMat(in);
return;
}
//we need at least L+1 input buffers
KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1);
propagate_buf_[0].Resize(in.NumRows(),in.NumCols(),kUndefined);
propagate_buf_[0].CopyFromMat(in);
for(int32 i=0; i<(int32)nnet_.size(); i++) {
nnet_[i]->Propagate(propagate_buf_[i],&propagate_buf_[i+1]);
}
Matrix<BaseFloat>& mat = propagate_buf_[nnet_.size()];
out->Resize(mat.NumRows(),mat.NumCols(),kUndefined);
out->CopyFromMat(mat);
}
void Nnet::Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err) {
if(LayerCount() == 0) { KALDI_ERR << "Cannot backpropagate on empty network"; }
//we need at least L+1 input buffers
KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1);
//we need at least L-1 error buffers
KALDI_ASSERT((int32)backpropagate_buf_.size() >= LayerCount()-1);
//find out when we can stop backprop
int32 backprop_stop = -1;
if(NULL == out_err) {
backprop_stop++;
while(1) {
if(nnet_[backprop_stop]->IsUpdatable()) {
if(0.0 != dynamic_cast<UpdatableComponent*>(nnet_[backprop_stop])->LearnRate()) {
break;
}
}
backprop_stop++;
if(backprop_stop == (int32)nnet_.size()) {
KALDI_ERR << "All layers have zero learning rate!";
break;
}
}
}
//NOTE: the early-stop optimization above is disabled for now, always backpropagate through all layers
backprop_stop=-1;
//////////////////////////////////////
// Backpropagation
//
//don't copy the in_err to buffers, use it as is...
int32 i = nnet_.size()-1;
if(nnet_[i]->IsUpdatable()) {
UpdatableComponent* uc = dynamic_cast<UpdatableComponent*>(nnet_[i]);
if(uc->LearnRate() > 0.0) {
uc->Update(propagate_buf_[i],in_err);
}
}
nnet_.back()->Backpropagate(in_err,&backpropagate_buf_[i-1]);
//backpropagate by using buffers
for(i--; i >= 1; i--) {
if(nnet_[i]->IsUpdatable()) {
UpdatableComponent* uc = dynamic_cast<UpdatableComponent*>(nnet_[i]);
if(uc->LearnRate() > 0.0) {
uc->Update(propagate_buf_[i],backpropagate_buf_[i]);
}
}
if(backprop_stop == i) break;
nnet_[i]->Backpropagate(backpropagate_buf_[i],&backpropagate_buf_[i-1]);
}
//update first layer
if(nnet_[0]->IsUpdatable() && 0 >= backprop_stop) {
UpdatableComponent* uc = dynamic_cast<UpdatableComponent*>(nnet_[0]);
if(uc->LearnRate() > 0.0) {
uc->Update(propagate_buf_[0],backpropagate_buf_[0]);
}
}
//now backpropagate through first layer, but only if asked to (by out_err pointer)
if(NULL != out_err) {
nnet_[0]->Backpropagate(backpropagate_buf_[0],out_err);
}
//
// End of Backpropagation
//////////////////////////////////////
}
void Nnet::Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out) {
KALDI_ASSERT(NULL != out);
if(LayerCount() == 0) {
out->Resize(in.NumRows(),in.NumCols(),kUndefined);
out->CopyFromMat(in);
return;
}
//we need at least 2 input buffers
KALDI_ASSERT(propagate_buf_.size() >= 2);
//propagate by using exactly 2 auxiliary buffers
int32 L = 0;
nnet_[L]->Propagate(in,&propagate_buf_[L%2]);
for(L++; L<=LayerCount()-2; L++) {
nnet_[L]->Propagate(propagate_buf_[(L-1)%2],&propagate_buf_[L%2]);
}
nnet_[L]->Propagate(propagate_buf_[(L-1)%2],out);
}
void Nnet::Read(std::istream& in, bool binary) {
//get the network layers from a factory
Component *comp;
while(NULL != (comp = Component::Read(in,binary,this))) {
if(LayerCount() > 0 && nnet_.back()->OutputDim() != comp->InputDim()) {
KALDI_ERR << "Dimensionality mismatch!"
<< " Previous layer output:" << nnet_.back()->OutputDim()
<< " Current layer input:" << comp->InputDim();
}
nnet_.push_back(comp);
}
//create empty buffers
propagate_buf_.resize(LayerCount()+1);
backpropagate_buf_.resize(LayerCount()-1);
//reset learn rate
learn_rate_ = 0.0;
}
void Nnet::LearnRate(BaseFloat lrate, const char* lrate_factors) {
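//lrate_factors, if non-NULL, is a separator-delimited list such as "1.0,1.0,0.0" (values here
//are just illustrative) with one factor per trainable layer; the global lrate is multiplied by
//the factor, so a factor of 0.0 effectively freezes that layer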
//split lrate_factors to a vector
std::vector<BaseFloat> lrate_factor_vec;
if(NULL != lrate_factors) {
char* copy = new char[strlen(lrate_factors)+1];
strcpy(copy, lrate_factors);
char* tok = NULL;
while(NULL != (tok = strtok((tok==NULL?copy:NULL),",:; "))) {
lrate_factor_vec.push_back(atof(tok));
}
delete [] copy; //allocated with new[], so array delete is needed
}
//count trainable layers
int32 updatable = 0;
for(int i=0; i<LayerCount(); i++) {
if(nnet_[i]->IsUpdatable()) updatable++;
}
//check number of factors
if(lrate_factor_vec.size() > 0 && updatable != (int32)lrate_factor_vec.size()) {
KALDI_ERR << "Mismatch between number of trainable layers " << updatable
<< " and learn rate factors " << lrate_factor_vec.size();
}
//set learn rates
updatable=0;
for(int32 i=0; i<LayerCount(); i++) {
if(nnet_[i]->IsUpdatable()) {
BaseFloat lrate_scaled = lrate;
if(lrate_factor_vec.size() > 0) lrate_scaled *= lrate_factor_vec[updatable++];
dynamic_cast<UpdatableComponent*>(nnet_[i])->LearnRate(lrate_scaled);
}
}
//set global learn rate
learn_rate_ = lrate;
}
std::string Nnet::LearnRateString() {
std::ostringstream oss;
oss << "LEARN_RATE global: " << learn_rate_ << " individual: ";
for(int32 i=0; i<LayerCount(); i++) {
if(nnet_[i]->IsUpdatable()) {
oss << dynamic_cast<UpdatableComponent*>(nnet_[i])->LearnRate() << " ";
}
}
return oss.str();
}
} // namespace

223
src/nnet/nnet-nnet.h Normal file
Просмотреть файл

@ -0,0 +1,223 @@
// nnet/nnet-nnet.h
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET_NNET_H
#define KALDI_NNET_NNET_H
#include "base/kaldi-common.h"
#include "util/kaldi-io.h"
#include "matrix/matrix-lib.h"
#include "nnet/nnet-component.h"
#include <iostream>
#include <sstream>
#include <vector>
namespace kaldi {
class Nnet {
//////////////////////////////////////
// Typedefs
typedef std::vector<Component*> NnetType;
//////////////////////////////////////////////////////////////
// Disable copy construction and assignment
private:
Nnet(Nnet&);
Nnet& operator=(Nnet&);
//////////////////////////////////////////////////////////////
// Constructor & Destructor
public:
Nnet()
{ }
~Nnet(); //{ } later...
//////////////////////////////////////////////////////////////
// Public interface
public:
/// Perform forward pass through the network
void Propagate(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out);
/// Perform backward pass through the network
void Backpropagate(const Matrix<BaseFloat>& in_err, Matrix<BaseFloat>* out_err);
/// Perform forward pass through the network, don't keep buffers (use it when not training)
void Feedforward(const Matrix<BaseFloat>& in, Matrix<BaseFloat>* out);
MatrixIndexT InputDim() const; ///< Dimensionality of the input features
MatrixIndexT OutputDim() const; ///< Dimensionality of the output (target) vectors
MatrixIndexT LayerCount() const { ///< Get number of layers
return nnet_.size();
}
Component* Layer(MatrixIndexT index) { ///< Access to individual layer
return nnet_[index];
}
int IndexOfLayer(const Component& comp) const; ///< Get the position of layer in network
/// Access to forward pass buffers
const std::vector<Matrix<BaseFloat> >& PropagateBuffer() const {
return propagate_buf_;
}
/// Access to backward pass buffers
const std::vector<Matrix<BaseFloat> >& BackpropagateBuffer() const {
return backpropagate_buf_;
}
/// Read the MLP from file (can add layers to existing instance of Nnet)
void Read(const std::string& file);
/// Read the MLP from stream (can add layers to existing instance of Nnet)
void Read(std::istream& in, bool binary);
/// Write MLP to file
void Write(const std::string& file, bool binary);
/// Write MLP to stream
void Write(std::ostream& out, bool binary);
/// Set the learning rate values to trainable layers,
/// factors can disable training of individual layers
void LearnRate(BaseFloat lrate, const char* lrate_factors);
/// Get the global learning rate value
BaseFloat LearnRate() {
return learn_rate_;
}
/// Get the string with real learning rate values
std::string LearnRateString();
void Momentum(BaseFloat mmt);
void L2Penalty(BaseFloat l2);
void L1Penalty(BaseFloat l1);
//////////////////////////////////////////////////////////////
// Private interface
private:
/// Creates a component by reading from stream, return NULL if no more components
static Component* ComponentFactory(std::istream& in, bool binary, Nnet* nnet);
/// Dumps individual component to stream
static void ComponentDumper(std::ostream& out, bool binary, const Component& comp);
private:
NnetType nnet_; ///< vector of all Component*, represents layers
std::vector<Matrix<BaseFloat> > propagate_buf_; ///< buffers for forward pass
std::vector<Matrix<BaseFloat> > backpropagate_buf_; ///< buffers for backward pass
BaseFloat learn_rate_; ///< global learning rate
};
//////////////////////////////////////////////////////////////////////////
// INLINE FUNCTIONS
// Nnet::
inline Nnet::~Nnet() {
//delete all the components
NnetType::iterator it;
for(it=nnet_.begin(); it!=nnet_.end(); ++it) {
delete *it;
}
}
inline MatrixIndexT Nnet::InputDim() const {
if(LayerCount() > 0) {
return nnet_.front()->InputDim();
} else {
KALDI_ERR << "No layers in MLP";
}
}
inline MatrixIndexT Nnet::OutputDim() const {
if(LayerCount() > 0) {
return nnet_.back()->OutputDim();
} else {
KALDI_ERR << "No layers in MLP";
}
}
inline int32 Nnet::IndexOfLayer(const Component& comp) const {
for(int32 i=0; i<LayerCount(); i++) {
if(&comp == nnet_[i]) return i;
}
KALDI_ERR << "Component:" << &comp
<< " type:" << comp.GetType()
<< " not found in the MLP";
return -1;
}
inline void Nnet::Read(const std::string& file) {
bool binary;
Input in(file,&binary);
Read(in.Stream(),binary);
in.Close();
}
inline void Nnet::Write(const std::string& file, bool binary) {
Output out(file, binary, true);
Write(out.Stream(),binary);
out.Close();
}
inline void Nnet::Write(std::ostream& out, bool binary) {
for(int32 i=0; i<LayerCount(); i++) {
nnet_[i]->Write(out,binary);
}
}
inline void Nnet::Momentum(BaseFloat mmt) {
for(int32 i=0; i<LayerCount(); i++) {
if(nnet_[i]->IsUpdatable()) {
dynamic_cast<UpdatableComponent*>(nnet_[i])->Momentum(mmt);
}
}
}
inline void Nnet::L2Penalty(BaseFloat l2) {
for(int32 i=0; i<LayerCount(); i++) {
if(nnet_[i]->IsUpdatable()) {
dynamic_cast<UpdatableComponent*>(nnet_[i])->L2Penalty(l2);
}
}
}
inline void Nnet::L1Penalty(BaseFloat l1) {
for(int32 i=0; i<LayerCount(); i++) {
if(nnet_[i]->IsUpdatable()) {
dynamic_cast<UpdatableComponent*>(nnet_[i])->L1Penalty(l1);
}
}
}
} //namespace kaldi
#endif

43
src/nnet/nnet-test.cc Normal file
Просмотреть файл

@ -0,0 +1,43 @@
// nnet/nnet-test.cc
// Copyright 2010 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "base/kaldi-common.h"
#include "nnet/nnet-component.h"
#include "nnet/nnet-nnet.h"
using namespace kaldi;
static void UnitTestSomething() {
KALDI_ERR << "Unimeplemented";
}
static void UnitTestNnet() {
try {
UnitTestSomething();
} catch (const std::exception& e) {
std::cerr << e.what();
}
}
int main() {
UnitTestNnet();
}

44
src/nnetbin/Makefile Normal file
Просмотреть файл

@ -0,0 +1,44 @@
all:
EXTRA_CXXFLAGS = -Wno-sign-compare
include ../kaldi.mk
BINFILES = nnet-train-xent-hardlab-perutt
OBJFILES =
all: $(BINFILES)
TESTFILES =
$(BINFILES): ../nnet/kaldi-nnet.a ../matrix/kaldi-matrix.a ../util/kaldi-util.a ../base/kaldi-base.a
# Rule below would expand to, e.g.:
# ../base/kaldi-base.a:
# make -C ../base kaldi-base.a
# -C option to make is same as changing directory.
%.a:
$(MAKE) -C ${@D} ${@F}
clean:
rm *.o *.a $(TESTFILES) $(BINFILES)
test: $(TESTFILES)
for x in $(TESTFILES); do ./$$x >&/dev/null || (echo "***test $$x failed***"; exit 1); done
echo Tests succeeded
.valgrind: $(TESTFILES)
depend:
-$(CXX) -M $(CXXFLAGS) *.cc > .depend.mk
# removing automatic making of "depend" as it's quite slow.
#.depend.mk: depend
-include .depend.mk

Просмотреть файл

@ -0,0 +1,155 @@
// nnetbin/nnet-train-xent-hardlab-perutt.cc
// Copyright 2011 Karel Vesely
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "util/timer.h"
int main(int argc, char *argv[])
{
using namespace kaldi;
try {
const char *usage =
"Perform iteration of Neural Network training by stochastic gradient descent.\n"
"Usage: nnet-train-xent-hardlab-perutt [options] <model-in> <feature-rspecifier> <alignments-rspecifier> [<model-out>]\n"
"e.g.: \n"
" nnet-train-xent-hardlab-perutt nnet.init scp:train.scp ark:train.ali nnet.iter1\n";
ParseOptions po(usage);
bool binary = false,
crossvalidate = false;
po.Register("binary", &binary, "Write output in binary mode");
po.Register("cross-validate", &crossvalidate, "Perform cross-validation (don't backpropagate)");
BaseFloat learn_rate = 0.008,
momentum = 0.0,
l2_penalty = 0.0,
l1_penalty = 0.0;
po.Register("learn-rate", &learn_rate, "Learning rate");
po.Register("momentum", &momentum, "Momentum");
po.Register("l2-penalty", &l2_penalty, "L2 penalty (weight decay)");
po.Register("l1-penalty", &l1_penalty, "L1 penalty (promote sparsity)");
std::string feature_transform;
po.Register("feature-transform", &feature_transform, "Feature transform Neural Network");
po.Read(argc, argv);
if (po.NumArgs() != 4-(crossvalidate?1:0)) {
po.PrintUsage();
exit(1);
}
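// in --cross-validate=true mode no model is written, so <model-out> is omitted
// and only three positional arguments are expected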
std::string model_filename = po.GetArg(1),
feature_rspecifier = po.GetArg(2),
alignments_rspecifier = po.GetArg(3);
std::string target_model_filename;
if(!crossvalidate) {
target_model_filename = po.GetArg(4);
}
using namespace kaldi;
typedef kaldi::int32 int32;
Nnet nnet_transf;
if(feature_transform != "") {
nnet_transf.Read(feature_transform);
}
Nnet nnet;
nnet.Read(model_filename);
nnet.LearnRate(learn_rate,NULL);
nnet.Momentum(momentum);
nnet.L2Penalty(l2_penalty);
nnet.L1Penalty(l1_penalty);
kaldi::int64 tot_t = 0;
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
RandomAccessInt32VectorReader alignments_reader(alignments_rspecifier);
Xent xent;
Matrix<BaseFloat> feats_transf, nnet_out, glob_err;
Timer tim;
KALDI_LOG << (crossvalidate?"CROSSVALIDATE":"TRAINING") << " STARTED";
int32 num_done = 0, num_no_alignment = 0, num_other_error = 0;
for (; !feature_reader.Done(); feature_reader.Next()) {
std::string key = feature_reader.Key();
if (!alignments_reader.HasKey(key)) {
num_no_alignment++;
} else {
const Matrix<BaseFloat> &mat = feature_reader.Value();
const std::vector<int32> &alignment = alignments_reader.Value(key);
//std::cout << mat;
if ((int32)alignment.size() != mat.NumRows()) {
KALDI_WARN << "Alignment has wrong size "<< (alignment.size()) << " vs. "<< (mat.NumRows());
num_other_error++;
continue;
}
if(num_done % 10000 == 0) std::cout << num_done << ", " << std::flush;
num_done++;
nnet_transf.Feedforward(mat,&feats_transf);
nnet.Propagate(feats_transf,&nnet_out);
//std::cout << "\nNETOUT" << nnet_out;
xent.Eval(nnet_out,alignment,&glob_err);
//std::cout << "\nALIGN" << alignment[0] << " "<< alignment[1]<< " "<< alignment[2];
//std::cout << "\nGLOBERR" << glob_err;
if(!crossvalidate) {
nnet.Backpropagate(glob_err,NULL);
}
tot_t += mat.NumRows();
}
}
if(!crossvalidate) {
nnet.Write(target_model_filename,binary);
}
std::cout << "\n" << std::flush;
KALDI_LOG << (crossvalidate?"CROSSVALIDATE":"TRAINING") << " FINISHED "
<< tim.Elapsed() << "s, fps" << tot_t/tim.Elapsed();
KALDI_LOG << "Done " << num_done << " files, " << num_no_alignment
<< " with no alignments, " << num_other_error
<< " with other errors.";
KALDI_LOG << xent.Report();
return 0;
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
}
}