Merge branch 'master' into chain

This commit is contained in:
Dan Povey 2016-04-06 10:40:40 -07:00
Parents: 2378680144 24f553cf04
Commit: ce708ea167
30 changed files: 434 additions, 249 deletions

View file

@ -1,4 +1,4 @@
#!/bin/bash
#
if [ -f path.sh ]; then . path.sh; fi
@ -15,25 +15,12 @@ arpa_lm=$1
cp -r data/lang data/lang_test
# grep -v '<s> <s>' etc. is only for future-proofing this script. Our
# LM doesn't have these "invalid combinations". These can cause
# determinization failures of CLG [ends up being epsilon cycles].
# Note: remove_oovs.pl takes a list of words in the LM that aren't in
# our word list. Since our LM doesn't have any, we just give it
# /dev/null [we leave it in the script to show how you'd do it].
gunzip -c "$arpa_lm" | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
arpa2fst - | fstprint | \
utils/remove_oovs.pl /dev/null | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=data/lang_test/words.txt \
--osymbols=data/lang_test/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang_test/G.fst
fstisstochastic data/lang_test/G.fst
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=data/lang_test/words.txt - data/lang_test/G.fst
echo "Checking how stochastic G is (the first of these numbers should be small):"
fstisstochastic data/lang_test/G.fst
## Check lexicon.
## just have a look and make sure it seems sane.
@ -61,4 +48,3 @@ fsttablecompose data/lang/L_disambig.fst data/lang_test/G.fst | \
fstisstochastic || echo LG is not stochastic
echo AMI_format_data succeeded.

View file

@ -0,0 +1,37 @@
Example scripts showing how to use a pre-trained English chain model and the kaldi code base to recognize any number of wav files.
IMPORTANT: wav files must be in 16kHz, 16 bit little-endian format.
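If a recording is not already in that format, one way to convert it (a sketch assuming the sox tool is installed; sox is not part of these scripts) is:
sox input.wav -r 16000 -b 16 -c 1 test1.wav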
Model:
The pretrained English model was released by Api.ai under the Creative Commons Attribution-ShareAlike 4.0 International Public License.
- The acoustic data is mostly mobile-recorded data
- The language model is based on Assistant.ai logs and is good for understanding short commands, like "Wake me up at 7 am"
For more details, visit https://github.com/api-ai/api-ai-english-asr-model
Usage:
- Ensure kaldi is compiled and these scripts are inside a kaldi/egs/<subfolder>/ directory
- Run ./download-model.sh to download the pretrained chain model
- Run ./recognize-wav.sh test1.wav test2.wav to do recognition
- See output for recognition results
Using the steps/nnet3/decode.sh script:
You can use kaldi's steps/nnet3/decode.sh, which will decode the data and calculate its Word Error Rate (WER).
Steps:
- Run recognize-wav.sh test1.wav test2.wav; it will make the data dir, calculate mfcc features for it and do decoding; only the first two of those steps are needed here
- If you want WER then edit data/test-corpus/text and replace NO_TRANSCRIPTION with the expected text transcription for every wav file (a combined sketch follows this list)
- steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/
- See exp/api.ai-model/decode/wer* files for WER and exp/api.ai-model/decode/log/ files for decoding output
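Putting those steps together, a hypothetical end-to-end WER check could look like this (paths follow the steps above):
./recognize-wav.sh test1.wav test2.wav
# edit data/test-corpus/text: replace NO_TRANSCRIPTION with the reference transcriptions
steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/
grep WER exp/api.ai-model/decode/wer_*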
Online Decoder:
At the moment kaldi does not support online decoding for nnet3 models, but suitable decoders can be found at https://github.com/api-ai/kaldi/.
See http://kaldi.sourceforge.net/online_decoding.html for more information about kaldi online decoding.
Steps:
- Run ./local/create-corpus.sh data/test-corpus/ test1.wav test2.wav (or just run recognize-wav.sh) to create the corpus
- If you want WER then edit data/test-corpus/text and replace NO_TRANSCRIPTION with the expected text transcription for every wav file
- Make a config file exp/api.ai-model/online.conf with the following content
==CONTENT START==
--feature-type=mfcc
--mfcc-config=exp/api.ai-model/mfcc.conf
==CONTENT END==
- Run steps/online/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/
- See exp/api.ai-model/decode/wer* files for WER and exp/api.ai-model/decode/log/ files for decoding output

View file

@ -0,0 +1,24 @@
#!/bin/bash
# Downloads the Api.ai chain model into exp/api.ai-model/ (replacing an existing one if present)
DOWNLOAD_URL="https://api.ai/downloads/api.ai-kaldi-asr-model.zip"
echo "Downloading model"
wget -N "$DOWNLOAD_URL" || { echo "Unable to download model: $DOWNLOAD_URL"; exit 1; }
echo "Unpacking model"
unzip api.ai-kaldi-asr-model.zip || { echo "Unable to extract api.ai-kaldi-asr-model.zip"; exit 1; }
echo "Moving model to exp/api.ai-model/"
if [ ! -d exp ]; then
mkdir exp;
fi;
if [ -d exp/api.ai-model ]; then
echo "Found existing model, removing";
rm -rf exp/api.ai-model/
fi
mv api.ai-kaldi-asr-model exp/api.ai-model || { echo "Unable to move model to exp/"; exit 1; }
echo "Model is ready; use recognize-wav.sh to do voice recognition"

View file

@ -0,0 +1,50 @@
#!/bin/bash
# Checking arguments
if [ $# -le 1 ]; then
echo "Use $0 <datadir> test1.wav [test2.wav] ..."
echo " $0 data/test-corpus test1.wav test2.wav"
exit 1;
fi
CORPUS=$1
shift
for file in "$@"; do
if [[ "$file" != *.wav ]]; then
echo "Expecting .wav files, got $file"
exit 1;
fi
if [ ! -f "$file" ]; then
echo "$file not found";
exit 1;
fi
done;
echo "Initilizing $CORPUS"
if [ ! -d "$CORPUS" ]; then
echo "Creating $CORPUS directory"
mkdir -p "$CORPUS" || ( echo "Unable to create data dir" && exit 1 )
fi;
wav_scp="$CORPUS/wav.scp"
spk2utt="$CORPUS/spk2utt"
utt2spk="$CORPUS/utt2spk"
text="$CORPUS/text"
#nulling files
cat </dev/null >$wav_scp
cat </dev/null >$spk2utt
cat </dev/null >$utt2spk
cat </dev/null >$text
rm $CORPUS/feats.scp 2>/dev/null;
rm $CORPUS/cmvn.scp 2>/dev/null;
for file in "$@"; do
id=$(echo "$file" | sed -e 's/ /_/g')
echo "$id $file" >>$wav_scp
echo "$id $id" >>$spk2utt
echo "$id $id" >>$utt2spk
echo "$id NO_TRANSRIPTION" >>$text
done;

View file

@ -0,0 +1 @@
../../wsj/s5/local/score.sh

egs/apiai_decode/s5/path.sh Executable file
View file

@ -0,0 +1,6 @@
export KALDI_ROOT=`pwd`/../../..
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
[ ! -f $KALDI_ROOT/src/path.sh ] && echo >&2 "The standard file $KALDI_ROOT/src/path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/src/path.sh
export LC_ALL=C

View file

@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2016 Api.ai (Author: Ilya Platonov)
# Apache 2.0
# This script demonstrates kaldi decoding using a pretrained model. It will decode a list of wav files.
#
# IMPORTANT: wav files must be in 16kHz, 16 bit little-endian format.
#
# This script tries to follow what other scripts do in terms of directory structure and data handling.
#
# Use the ./download-model.sh script to download the ASR model.
# See https://github.com/api-ai/api-ai-english-asr-model for details about the model and how to use it.
. path.sh
MODEL_DIR="exp/api.ai-model"
DATA_DIR="data/test-corpus"
echo "///////"
echo "// IMPORTANT: wav files must be in 16kHz, 16 bit little-endian format."
echo "//////";
for file in final.mdl HCLG.fst words.txt frame_subsampling_factor; do
if [ ! -f $MODEL_DIR/$file ]; then
echo "$MODEL_DIR/$file not found, use ./download-model.sh"
exit 1;
fi
done;
for app in nnet3-latgen-faster apply-cmvn lattice-scale; do
command -v $app >/dev/null 2>&1 || { echo >&2 "$app not found, is kaldi compiled?"; exit 1; }
done;
local/create-corpus.sh $DATA_DIR "$@" || exit 1;
echo "///////"
echo "// Computing mfcc and cmvn (cmvn is not really used)"
echo "//////";
steps/make_mfcc.sh --nj 1 --mfcc-config $MODEL_DIR/mfcc.conf \
--cmd "run.pl" $DATA_DIR exp/make_mfcc exp/mfcc || { echo "Unable to calculate mfcc, ensure 16kHz, 16 bit little-endian wav format or see log"; exit 1; };
steps/compute_cmvn_stats.sh $DATA_DIR exp/make_mfcc/ exp/mfcc || exit 1;
echo "///////"
echo "// Doing decoding (see log for results)"
echo "//////";
frame_subsampling_factor=$(cat $MODEL_DIR/frame_subsampling_factor)
nnet3-latgen-faster --frame-subsampling-factor=$frame_subsampling_factor --frames-per-chunk=50 --extra-left-context=0 \
--extra-right-context=0 --extra-left-context-initial=-1 --extra-right-context-final=-1 \
--minimize=false --max-active=7000 --min-active=200 --beam=15.0 --lattice-beam=8.0 \
--acoustic-scale=1.0 --allow-partial=true \
--word-symbol-table=$MODEL_DIR/words.txt $MODEL_DIR/final.mdl $MODEL_DIR/HCLG.fst \
"ark,s,cs:apply-cmvn --norm-means=false --norm-vars=false --utt2spk=ark:$DATA_DIR/utt2spk scp:$DATA_DIR/cmvn.scp scp:$DATA_DIR/feats.scp ark:- |" \
"ark:|lattice-scale --acoustic-scale=10.0 ark:- ark:- >exp/lat.1"

egs/apiai_decode/s5/steps Symbolic link
View file

@ -0,0 +1 @@
../../wsj/s5/steps/

egs/apiai_decode/s5/utils Symbolic link
View file

@ -0,0 +1 @@
../../wsj/s5/utils/

View file

@ -18,40 +18,23 @@ fi
lm_dir=$1
tmpdir=data/local/lm_tmp
lexicon=data/local/lang_tmp/lexiconp.txt
mkdir -p $tmpdir
# This loop was taken verbatim from wsj_format_data.sh, and I'm leaving it in place in
# case we decide to add more language models at some point
for lm_suffix in tgpr; do
test=data/lang_test_${lm_suffix}
mkdir -p $test
for f in phones.txt words.txt phones.txt L.fst L_disambig.fst phones oov.txt oov.int; do
for f in phones.txt words.txt phones.txt L.fst L_disambig.fst phones topo oov.txt oov.int; do
cp -r data/lang/$f $test
done
gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz |\
utils/find_arpa_oovs.pl $test/words.txt > $tmpdir/oovs_${lm_suffix}.txt || exit 1
# grep -v '<s> <s>' because the LM seems to have some strange and useless
# stuff in it with multiple <s>'s in the history. Encountered some other similar
# things in a LM from Geoff. Removing all "illegal" combinations of <s> and </s>,
# which are supposed to occur only at begin/end of utt. These can cause
# determinization failures of CLG [ends up being epsilon cycles].
gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
arpa2fst - | fstprint | \
utils/remove_oovs.pl $tmpdir/oovs_${lm_suffix}.txt | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$test/words.txt \
--osymbols=$test/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > $test/G.fst
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$test/words.txt - $test/G.fst
utils/validate_lang.pl $test || exit 1;
done
echo "Succeeded in formatting data."
rm -r $tmpdir
exit 0

View file

@ -49,24 +49,9 @@ for lm_suffix in tgsmall tgmed; do
test=${src_dir}_test_${lm_suffix}
mkdir -p $test
cp -r ${src_dir}/* $test
gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz |\
utils/find_arpa_oovs.pl $test/words.txt > $tmpdir/oovs_${lm_suffix}.txt || exit 1
# grep -v '<s> <s>' because the LM seems to have some strange and useless
# stuff in it with multiple <s>'s in the history. Encountered some other
# similar things in a LM from Geoff. Removing all "illegal" combinations of
# <s> and </s>, which are supposed to occur only at begin/end of utt. These
# can cause determinization failures of CLG [ends up being epsilon cycles].
gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
arpa2fst - | fstprint | \
utils/remove_oovs.pl $tmpdir/oovs_${lm_suffix}.txt | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$test/words.txt \
--osymbols=$test/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > $test/G.fst
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$test/words.txt - $test/G.fst
utils/validate_lang.pl --skip-determinization-check $test || exit 1;
done

View file

@ -40,10 +40,10 @@ open(DUR3, ">$dir/3sec") || die "Failed opening output file $dir/3sec";
open(DUR10, ">$dir/10sec") || die "Failed opening output file $dir/10sec";
open(DUR30, ">$dir/30sec") || die "Failed opening output file $dir/30sec";
my $key_str = `wget -qO- "http://www.itl.nist.gov/iad/mig/tests/lang/2007/lid07key_v5.txt"`;
my $key_str = `wget -qO- "http://www.openslr.org/resources/23/lre07_key.txt"`;
@key_lines = split("\n",$key_str);
%utt2lang = ();
%utt2dur = ();
foreach (@key_lines) {
@words = split(' ', $_);
if (index($words[0], "#") == -1) {

View file

@ -27,7 +27,7 @@ tmpdir=data/local/lm_tmp
lexicon=data/local/lang${lang_suffix}_tmp/lexiconp.txt
mkdir -p $tmpdir
for x in train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do
mkdir -p data/$x
cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1;
cp $srcdir/$x.txt data/$x/text || exit 1;
@ -49,22 +49,8 @@ for lm_suffix in bg tgpr tg bg_5k tgpr_5k tg_5k; do
cp -r data/lang${lang_suffix}/* $test || exit 1;
gunzip -c $lmdir/lm_${lm_suffix}.arpa.gz | \
utils/find_arpa_oovs.pl $test/words.txt > $tmpdir/oovs_${lm_suffix}.txt
# grep -v '<s> <s>' because the LM seems to have some strange and useless
# stuff in it with multiple <s>'s in the history. Encountered some other similar
# things in a LM from Geoff. Removing all "illegal" combinations of <s> and </s>,
# which are supposed to occur only at begin/end of utt. These can cause
# determinization failures of CLG [ends up being epsilon cycles].
gunzip -c $lmdir/lm_${lm_suffix}.arpa.gz | \
grep -v '<s> <s>' | \
grep -v '</s> <s>' | \
grep -v '</s> </s>' | \
arpa2fst - | fstprint | \
utils/remove_oovs.pl $tmpdir/oovs_${lm_suffix}.txt | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$test/words.txt \
--osymbols=$test/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > $test/G.fst
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$test/words.txt - $test/G.fst
utils/validate_lang.pl --skip-determinization-check $test || exit 1;
done

View file

@ -45,17 +45,13 @@ fi
# Be careful: this time we dispense with the grep -v '<s> <s>' so this might
# not work for LMs generated from all toolkits.
gunzip -c $lm_srcdir_3g/lm_pr6.0.gz | \
arpa2fst - | fstprint | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
--osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_tgpr/G.fst || exit 1;
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_tgpr/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_tgpr/G.fst
gunzip -c $lm_srcdir_3g/lm_unpruned.gz | \
arpa2fst - | fstprint | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
--osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_tg/G.fst || exit 1;
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_tg/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_tg/G.fst
# Build ConstArpaLm for the unpruned language model.
@ -65,10 +61,8 @@ gunzip -c $lm_srcdir_3g/lm_unpruned.gz | \
--unk-symbol=$unk - data/lang${lang_suffix}_test_bd_tgconst/G.carpa || exit 1
gunzip -c $lm_srcdir_4g/lm_unpruned.gz | \
arpa2fst - | fstprint | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
--osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_fg/G.fst || exit 1;
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_fg/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_fg/G.fst
# Build ConstArpaLm for the unpruned language model.
@ -78,10 +72,8 @@ gunzip -c $lm_srcdir_4g/lm_unpruned.gz | \
--unk-symbol=$unk - data/lang${lang_suffix}_test_bd_fgconst/G.carpa || exit 1
gunzip -c $lm_srcdir_4g/lm_pr7.0.gz | \
arpa2fst - | fstprint | \
utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
--osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_fgpr/G.fst || exit 1;
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_fgpr/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_fgpr/G.fst
exit 0;

View file

@ -42,11 +42,11 @@ Allowed options:
(default = "<***>")
--wer-cutoff : Ignore segments with WER higher than the specified value.
-1 means no segment will be ignored. (default = -1)
--use-silence-midpoints : Set to 1 if you want to use silence midpoints
instead of min_sil_length for silence overhang. (default 0)
--force-correct-boundary-words : Set to zero if the segments will not be
required to have boundary words to be correct. Default 1
--aligned-ctm-filename : If set, the intermediate aligned ctm
is saved to this file
EOU
@ -56,7 +56,7 @@ my $min_sil_length = 0.5;
my $separator = ";";
my $special_symbol = "<***>";
my $wer_cutoff = -1;
my $use_silence_midpoints = 0;
my $force_correct_boundary_words = 1;
my $aligned_ctm_filename = "";
GetOptions(
@ -122,13 +122,13 @@ sub PrintSegment {
# Works out the surrounding silence.
my $index = $seg_start_index - 1;
while ($index >= 0 && $aligned_ctm->[$index]->[0] eq
"<eps>" && $aligned_ctm->[$index]->[3] == 0) {
$index -= 1;
}
my $left_of_segment_has_deletion = "false";
$left_of_segment_has_deletion = "true"
if ($index > 0 && $aligned_ctm->[$index-1]->[0] ne "<eps>"
&& $aligned_ctm->[$index-1]->[3] == 0);
my $pad_start_sil = ($aligned_ctm->[$seg_start_index]->[1] -
@ -141,11 +141,11 @@ sub PrintSegment {
my $right_of_segment_has_deletion = "false";
$index = $seg_end_index + 1;
while ($index < scalar(@{$aligned_ctm}) &&
$aligned_ctm->[$index]->[0] eq "<eps>" &&
$aligned_ctm->[$index]->[3] == 0) {
$index += 1;
}
$right_of_segment_has_deletion = "true"
if ($index < scalar(@{$aligned_ctm})-1 && $aligned_ctm->[$index+1]->[0] ne
"<eps>" && $aligned_ctm->[$index - 1]->[3] > 0);
my $pad_end_sil = ($aligned_ctm->[$index - 1]->[1] +
@ -155,7 +155,7 @@ sub PrintSegment {
if (($right_of_segment_has_deletion eq "true") || !$use_silence_midpoints) {
if ($pad_end_sil > $min_sil_length / 2.0) {
$pad_end_sil = $min_sil_length / 2.0;
}
}
my $seg_start = $aligned_ctm->[$seg_start_index]->[1] - $pad_start_sil;
@ -270,6 +270,7 @@ sub ProcessWav {
$current_ctm, $current_align, $SO, $TO, $ACT) = @_;
my $wav_id = $current_ctm->[0]->[0];
my $channel_id = $current_ctm->[0]->[1];
defined($wav_id) || die "Error: empty wav section\n";
# First, we have to align the ctm file to the Levenshtein alignment.
@ -324,8 +325,8 @@ sub ProcessWav {
# Save the aligned CTM if needed
if(defined($ACT)){
for (my $i=0; $i<=$#aligned_ctm; $i++) {
print $ACT "$aligned_ctm[$i][0] $aligned_ctm[$i][1] ";
for (my $i = 0; $i <= $#aligned_ctm; $i++) {
print $ACT "$wav_id $channel_id $aligned_ctm[$i][0] $aligned_ctm[$i][1] ";
print $ACT "$aligned_ctm[$i][2] $aligned_ctm[$i][3]\n";
}
}
@ -346,8 +347,8 @@ sub ProcessWav {
# length, and if there are no alignment errors around it. We also make sure
# that the segment contains actual words, instead of pure silence.
if ($aligned_ctm[$x]->[0] eq "<eps>" &&
$aligned_ctm[$x]->[2] >= $min_sil_length
&& (($force_correct_boundary_words && $lcorrect eq "true" &&
$rcorrect eq "true") || !$force_correct_boundary_words)) {
if ($current_seg_length <= $max_seg_length &&
$current_seg_length >= $min_seg_length) {
@ -379,7 +380,7 @@ sub ProcessWav {
# 011 A 3.39 0.23 SELL
# 011 A 3.62 0.18 OFF
# 011 A 3.83 0.45 ASSETS
#
# Output ctm:
# 011 A 3.39 0.23 SELL
# 011 A 3.62 0.18 OFF

View file

@ -111,10 +111,8 @@ while read line; do
if (invoc[$x]) { printf("%s ", $x); } else { printf("%s ", oov); } }
printf("\n"); }' > $wdir/text
ngram-count -text $wdir/text -order $ngram_order "$srilm_options" -lm - |\
arpa2fst - | fstprint | utils/eps2disambig.pl | utils/s2eps.pl |\
fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \
--keep_isymbols=false --keep_osymbols=false |\
fstrmepsilon | fstarcsort --sort_type=ilabel > $wdir/G.fst || exit 1;
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$lang/words.txt - $wdir/G.fst || exit 1;
fi
fstisstochastic $wdir/G.fst || echo "$0: $uttid/G.fst not stochastic."
@ -134,7 +132,7 @@ while read line; do
make-h-transducer --disambig-syms-out=$wdir/disambig_tid.int \
--transition-scale=$tscale $wdir/ilabels_${N}_${P} \
$model_dir/tree $model_dir/final.mdl > $wdir/Ha.fst
# Builds HCLGa.fst
fsttablecompose $wdir/Ha.fst $wdir/CLG.fst | \
@ -143,10 +141,10 @@ while read line; do
fstminimizeencoded > $wdir/HCLGa.fst
fstisstochastic $wdir/HCLGa.fst ||\
echo "$0: $uttid/HCLGa.fst is not stochastic"
add-self-loops --self-loop-scale=$loopscale --reorder=true \
$model_dir/final.mdl < $wdir/HCLGa.fst > $wdir/HCLG.fst
if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then
fstisstochastic $wdir/HCLG.fst ||\
echo "$0: $uttid/HCLG.fst is not stochastic."

View file

@ -338,7 +338,7 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
left_context, right_context, min_deriv_time,
momentum, max_param_change,
shuffle_buffer_size, num_chunk_per_minibatch,
run_opts):
cache_read_opt, run_opts):
# We cannot easily use a single parallel SGE job to do the main training,
# because the computation of which archive and which --frame option
# to use for each job is a little complex, so we spawn each one separately.
@ -353,9 +353,15 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
# the other indexes from.
archive_index = (k % num_archives) + 1 # work out the 1-based archive index.
cache_write_opt = ""
if job == 1:
# an option for writing cache (storing pairs of nnet-computations and
# computation-requests) during training.
cache_write_opt="--write-cache={dir}/cache.{iter}".format(dir=dir, iter=iter+1)
process_handle = RunKaldiCommand("""
{command} {train_queue_opt} {dir}/log/train.{iter}.{job}.log \
nnet3-train {parallel_train_opts} \
nnet3-train {parallel_train_opts} {cache_read_opt} {cache_write_opt} \
--print-interval=10 --momentum={momentum} \
--max-param-change={max_param_change} \
--optimization.min-deriv-time={min_deriv_time} "{raw_model}" \
@ -365,6 +371,7 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
train_queue_opt = run_opts.train_queue_opt,
dir = dir, iter = iter, next_iter = iter + 1, job = job,
parallel_train_opts = run_opts.parallel_train_opts,
cache_read_opt = cache_read_opt, cache_write_opt = cache_write_opt,
momentum = momentum, max_param_change = max_param_change,
min_deriv_time = min_deriv_time,
raw_model = raw_model_string, context_opts = context_opts,
@ -387,7 +394,6 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
open('{0}/.error'.format(dir), 'w').close()
raise Exception("There was error during training iteration {0}".format(iter))
def TrainOneIteration(dir, iter, egs_dir,
num_jobs, num_archives_processed, num_archives,
learning_rate, shrinkage_value, num_chunk_per_minibatch,
@ -404,17 +410,21 @@ def TrainOneIteration(dir, iter, egs_dir,
if iter > 0:
ComputeProgress(dir, iter, egs_dir, run_opts)
# an option for writing cache (storing pairs of nnet-computations
# and computation-requests) during training.
cache_read_opt = ""
if iter > 0 and (iter <= (num_hidden_layers-1) * add_layers_period) and (iter % add_layers_period == 0):
do_average = False # if we've just mixed up, don't do averaging but take the
# best.
cur_num_hidden_layers = 1 + iter / add_layers_period
config_file = "{0}/configs/layer{1}.config".format(dir, cur_num_hidden_layers)
raw_model_string = "nnet3-am-copy --raw=true --learning-rate={lr} {dir}/{iter}.mdl - | nnet3-init --srand={iter} - {config} - |".format(lr=learning_rate, dir=dir, iter=iter, config=config_file)
else:
do_average = True
if iter == 0:
do_average = False # on iteration 0, pick the best, don't average.
else:
cache_read_opt = "--read-cache={dir}/cache.{iter}".format(dir=dir, iter=iter)
raw_model_string = "nnet3-am-copy --raw=true --learning-rate={0} {1}/{2}.mdl - |".format(learning_rate, dir, iter)
if do_average:
@ -437,7 +447,7 @@ def TrainOneIteration(dir, iter, egs_dir,
left_context, right_context, min_deriv_time,
momentum, max_param_change,
shuffle_buffer_size, cur_num_chunk_per_minibatch,
run_opts)
cache_read_opt, run_opts)
[models_to_average, best_model] = GetSuccessfulModels(num_jobs, '{0}/log/train.{1}.%.log'.format(dir,iter))
nnets_list = []
for n in models_to_average:
@ -477,6 +487,12 @@ nnet3-am-copy --scale={shrink} --set-raw-nnet=- {dir}/{iter}.mdl {dir}/{new_iter
raise Exception("Could not find {0}, at the end of iteration {1}".format(new_model, iter))
elif os.stat(new_model).st_size == 0:
raise Exception("{0} has size 0. Something went wrong in iteration {1}".format(new_model, iter))
try:
if cache_read_opt:
os.remove("{dir}/cache.{iter}".format(dir=dir, iter=iter))
except OSError:
raise Exception("Error while trying to delete the cache file")
# args is a Namespace with the required parameters
def Train(args, run_opts):

View file

@ -39,20 +39,9 @@ for f in phones.txt words.txt L.fst L_disambig.fst phones/; do
done
lm_base=$(basename $lm '.gz')
gunzip -c $lm | utils/find_arpa_oovs.pl $out_dir/words.txt \
> $out_dir/oovs_${lm_base}.txt
# Removing all "illegal" combinations of <s> and </s>, which are supposed to
# occur only at begin/end of utt. These can cause determinization failures
# of CLG [ends up being epsilon cycles].
gunzip -c $lm \
| egrep -v '<s> <s>|</s> <s>|</s> </s>' \
| arpa2fst - | fstprint \
| utils/remove_oovs.pl $out_dir/oovs_${lm_base}.txt \
| utils/eps2disambig.pl | utils/s2eps.pl \
| fstcompile --isymbols=$out_dir/words.txt --osymbols=$out_dir/words.txt \
--keep_isymbols=false --keep_osymbols=false \
| fstrmepsilon | fstarcsort --sort_type=ilabel > $out_dir/G.fst
arpa2fst --disambig-symbol=#0 \
--read-symbol-table=$out_dir/words.txt - $out_dir/G.fst
set +e
fstisstochastic $out_dir/G.fst
set -e
@ -66,7 +55,7 @@ set -e
# this might cause determinization failure of CLG.
# #0 is treated as an empty word.
mkdir -p $out_dir/tmpdir.g
awk '{if(NF==1){ printf("0 0 %s %s\n", $1,$1); }}
END{print "0 0 #0 #0"; print "0";}' \
< "$lexicon" > $out_dir/tmpdir.g/select_empty.fst.txt

View file

@ -44,7 +44,7 @@ struct NGramTestData {
float backoff;
};
std::ostream& operator<<(std::ostream& os, const NGramTestData& data) {
std::ostream& operator<<(std::ostream &os, const NGramTestData &data) {
std::ios::fmtflags saved_state(os.flags());
os << std::fixed << std::setprecision(6);
@ -62,7 +62,7 @@ template <class T>
struct CountedArray {
template <size_t N>
CountedArray(T(&array)[N]) : array(array), count(N) { }
const T* array;
const T *array;
const size_t count;
};
@ -73,7 +73,7 @@ inline CountedArray<T> MakeCountedArray(T(&array)[N]) {
class TestableArpaFileParser : public ArpaFileParser {
public:
TestableArpaFileParser(ArpaParseOptions options, fst::SymbolTable* symbols)
TestableArpaFileParser(ArpaParseOptions options, fst::SymbolTable *symbols)
: ArpaFileParser(options, symbols),
header_available_(false),
read_complete_(false),
@ -120,9 +120,10 @@ void TestableArpaFileParser::ReadComplete() {
read_complete_ = true;
}
//
bool CompareNgrams(const NGramTestData& actual,
const NGramTestData& expected) {
bool CompareNgrams(const NGramTestData &actual,
NGramTestData expected) {
expected.logprob *= Log(10.0);
expected.backoff *= Log(10.0);
if (actual.line_number != expected.line_number
|| !std::equal(actual.words, actual.words + kMaxOrder,
expected.words)
@ -164,33 +165,33 @@ ngram 2=2\n\
ngram 3=2\n\
\n\
\\1-grams:\n\
-5.234679 4 -3.3\n\
-3.456783 5\n\
0.0000000 1 -2.5\n\
-4.333333 2\n\
-5.2 4 -3.3\n\
-3.4 5\n\
0 1 -2.5\n\
-4.3 2\n\
\n\
\\2-grams:\n\
-1.45678 4 5 -3.23\n\
-1.30490 1 4 -4.2\n\
-1.4 4 5 -3.2\n\
-1.3 1 4 -4.2\n\
\n\
\\3-grams:\n\
-0.34958 1 4 5\n\
-0.23940 4 5 2\n\
-0.3 1 4 5\n\
-0.2 4 5 2\n\
\n\
\\end\\";
int32 expect_counts[] = { 4, 2, 2 };
NGramTestData expect_ngrams[] = {
{ 7, -12.05329, { 4, 0, 0 }, -7.598531 },
{ 8, -7.959537, { 5, 0, 0 }, 0.0 },
{ 9, 0.0, { 1, 0, 0 }, -5.756463 },
{ 10, -9.977868, { 2, 0, 0 }, 0.0 },
{ 7, -5.2, { 4, 0, 0 }, -3.3 },
{ 8, -3.4, { 5, 0, 0 }, 0.0 },
{ 9, 0.0, { 1, 0, 0 }, -2.5 },
{ 10, -4.3, { 2, 0, 0 }, 0.0 },
{ 13, -3.354360, { 4, 5, 0 }, -7.437350 },
{ 14, -3.004643, { 1, 4, 0 }, -9.670857 },
{ 13, -1.4, { 4, 5, 0 }, -3.2 },
{ 14, -1.3, { 1, 4, 0 }, -4.2 },
{ 17, -0.804938, { 1, 4, 5 }, 0.0 },
{ 18, -0.551239, { 4, 5, 2 }, 0.0 } };
{ 17, -0.3, { 1, 4, 5 }, 0.0 },
{ 18, -0.2, { 4, 5, 2 }, 0.0 } };
ArpaParseOptions options;
options.bos_symbol = 1;
@ -231,7 +232,6 @@ ngram 3=2\n\
\\3-grams:\n\
-0.3 <s> a \xCE\xB2\n\
-0.2 <s> a </s>\n\
\n\
\\end\\";
// Symbol table that is created with predefined test symbols, "a" but no "b".
@ -270,7 +270,6 @@ void ReadSymbolicLmNoOovImpl(ArpaParseOptions::OovHandling oov) {
options.bos_symbol = 1;
options.eos_symbol = 2;
options.unk_symbol = 3;
options.use_log10 = true;
options.oov_handling = oov;
TestableArpaFileParser parser(options, &symbols);
std::istringstream stm(symbolic_lm, std::ios_base::in);
@ -301,7 +300,6 @@ void ReadSymbolicLmWithOovImpl(
options.bos_symbol = 1;
options.eos_symbol = 2;
options.unk_symbol = 3;
options.use_log10 = true;
options.oov_handling = oov;
TestableArpaFileParser parser(options, symbols);
std::istringstream stm(symbolic_lm, std::ios_base::in);

View file

@ -31,7 +31,8 @@ namespace kaldi {
ArpaFileParser::ArpaFileParser(ArpaParseOptions options,
fst::SymbolTable* symbols)
: options_(options), symbols_(symbols), line_number_(0) {
: options_(options), symbols_(symbols),
line_number_(0), warning_count_(0) {
}
ArpaFileParser::~ArpaFileParser() {
@ -70,49 +71,51 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
ngram_counts_.clear();
line_number_ = 0;
warning_count_ = 0;
current_line_.clear();
#define PARSE_ERR (KALDI_ERR << "in line " << line_number_ << ": ")
#define PARSE_ERR (KALDI_ERR << LineReference() << ": ")
// Give derived class an opportunity to prepare its state.
ReadStarted();
std::string line;
// Processes "\data\" section.
bool keyword_found = false;
while (++line_number_, getline(is, line) && !is.eof()) {
if (line.empty()) continue;
while (++line_number_, getline(is, current_line_) && !is.eof()) {
if (current_line_.empty()) continue;
// Continue skipping lines until the \data\ marker alone on a line is found.
if (!keyword_found) {
if (line == "\\data\\") {
if (current_line_ == "\\data\\") {
KALDI_LOG << "Reading \\data\\ section.";
keyword_found = true;
}
continue;
}
if (line[0] == '\\') break;
if (current_line_[0] == '\\') break;
// Enters "\data\" section, and looks for patterns like "ngram 1=1000",
// which means there are 1000 unigrams.
std::size_t equal_symbol_pos = line.find("=");
std::size_t equal_symbol_pos = current_line_.find("=");
if (equal_symbol_pos != std::string::npos)
line.replace(equal_symbol_pos, 1, " = "); // Inserts spaces around "="
// Guaranteed spaces around the "=".
current_line_.replace(equal_symbol_pos, 1, " = ");
std::vector<std::string> col;
SplitStringToVector(line, " \t", true, &col);
SplitStringToVector(current_line_, " \t", true, &col);
if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") {
int32 order, ngram_count = 0;
if (!ConvertStringToInteger(col[1], &order) ||
!ConvertStringToInteger(col[3], &ngram_count)) {
PARSE_ERR << "Cannot parse ngram count '" << line << "'.";
PARSE_ERR << "cannot parse ngram count";
}
if (ngram_counts_.size() <= order) {
ngram_counts_.resize(order);
}
ngram_counts_[order - 1] = ngram_count;
} else {
KALDI_WARN << "Uninterpretable line in \\data\\ section: " << line;
KALDI_WARN << LineReference()
<< ": uninterpretable line in \\data\\ section";
}
}
@ -136,41 +139,38 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
// Must be looking at a \k-grams: directive at this point.
std::ostringstream keyword;
keyword << "\\" << cur_order << "-grams:";
if (line != keyword.str()) {
PARSE_ERR << "Invalid directive '" << line << "', "
<< "expecting '" << keyword.str() << "'.";
if (current_line_ != keyword.str()) {
PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'";
}
KALDI_LOG << "Reading " << line << " section.";
KALDI_LOG << "Reading " << current_line_ << " section.";
int32 ngram_count = 0;
while (++line_number_, getline(is, line) && !is.eof()) {
if (line.empty()) continue;
if (line[0] == '\\') break;
while (++line_number_, getline(is, current_line_) && !is.eof()) {
if (current_line_.empty()) continue;
if (current_line_[0] == '\\') break;
std::vector<std::string> col;
SplitStringToVector(line, " \t", true, &col);
SplitStringToVector(current_line_, " \t", true, &col);
if (col.size() < 1 + cur_order ||
col.size() > 2 + cur_order ||
(cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) {
PARSE_ERR << "Invalid n-gram line '" << line << "'";
PARSE_ERR << "Invalid n-gram data line";
}
++ngram_count;
// Parse out n-gram logprob and, if present, backoff weight.
if (!ConvertStringToReal(col[0], &ngram.logprob)) {
PARSE_ERR << "Invalid n-gram logprob '" << col[0] << "'.";
PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'";
}
ngram.backoff = 0.0;
if (col.size() > cur_order + 1) {
if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff))
PARSE_ERR << "Invalid backoff weight '" << col[cur_order + 1] << "'.";
}
// Convert to natural log unless the option is set not to.
if (!options_.use_log10) {
ngram.logprob *= M_LN10;
ngram.backoff *= M_LN10;
PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'";
}
// Convert to natural log.
ngram.logprob *= M_LN10;
ngram.backoff *= M_LN10;
ngram.words.resize(cur_order);
bool skip_ngram = false;
@ -188,11 +188,14 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
word = options_.unk_symbol;
break;
case ArpaParseOptions::kSkipNGram:
if (ShouldWarn())
KALDI_WARN << LineReference() << " skipped: word '"
<< col[1 + index] << "' not in symbol table";
skip_ngram = true;
break;
default:
PARSE_ERR << "Word '" << col[1 + index]
<< "' not in symbol table.";
PARSE_ERR << "word '" << col[1 + index]
<< "' not in symbol table";
}
}
}
@ -204,8 +207,8 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
}
// Whichever way we got it, an epsilon is invalid.
if (word == 0) {
PARSE_ERR << "Epsilon symbol '" << col[1 + index]
<< "' is illegal in ARPA LM.";
PARSE_ERR << "epsilon symbol '" << col[1 + index]
<< "' is illegal in ARPA LM";
}
ngram.words[index] = word;
}
@ -214,20 +217,36 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
}
}
if (ngram_count > ngram_counts_[cur_order - 1]) {
PARSE_ERR << "Header said there would be " << ngram_counts_[cur_order - 1]
<< " n-grams of order " << cur_order << ", but we saw "
<< ngram_count;
PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1]
<< " n-grams of order " << cur_order
<< ", but we saw more already.";
}
}
if (line != "\\end\\") {
PARSE_ERR << "Invalid or unexpected directive line '" << line << "', "
<< "expected \\end\\.";
if (current_line_ != "\\end\\") {
PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\";
}
if (warning_count_ > 0 && warning_count_ > (uint32)options_.max_warnings) {
KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
<< options_.max_warnings << " were reported. Run program with "
<< "--max_warnings=-1 to see all warnings";
}
current_line_.clear();
ReadComplete();
#undef PARSE_ERR
}
std::string ArpaFileParser::LineReference() const {
std::stringstream ss;
ss << "line " << line_number_ << " [" << current_line_ << "]";
return ss.str();
}
bool ArpaFileParser::ShouldWarn() {
return ++warning_count_ <= (uint32)options_.max_warnings;
}
} // namespace kaldi

View file

@ -27,6 +27,7 @@
#include <fst/fst-decl.h>
#include "base/kaldi-types.h"
#include "itf/options-itf.h"
namespace kaldi {
@ -43,13 +44,22 @@ struct ArpaParseOptions {
ArpaParseOptions()
: bos_symbol(-1), eos_symbol(-1), unk_symbol(-1),
oov_handling(kRaiseError), use_log10(false) { }
oov_handling(kRaiseError), max_warnings(30) { }
void Register(OptionsItf *opts) {
// Registering only the max_warnings count, since other options are
// treated differently by client programs: some want integer symbols,
// while others are passed words in their command line.
opts->Register("max-arpa-warnings", &max_warnings,
"Maximum warnings to report on ARPA parsing, "
"0 to disable, -1 to show all");
}
int32 bos_symbol; ///< Symbol for <s>, Required non-epsilon.
int32 eos_symbol; ///< Symbol for </s>, Required non-epsilon.
int32 unk_symbol; ///< Symbol for <unk>, Required for kReplaceWithUnk.
OovHandling oov_handling; ///< How to handle OOV words in the file.
bool use_log10; ///< Use log10 for prob and backoff weight, not ln.
int32 max_warnings; ///< Maximum warnings to report, <0 unlimited.
};
/**
@ -111,6 +121,14 @@ class ArpaFileParser {
/// Inside ConsumeNGram(), provides the current line number.
int32 LineNumber() const { return line_number_; }
/// Inside ConsumeNGram(), returns a formatted reference to the line being
/// compiled, to print out as part of diagnostics.
std::string LineReference() const;
/// Increments warning count, and returns true if a warning should be
/// printed or false if the count has exceeded the set maximum.
bool ShouldWarn();
/// N-gram counts. Valid in and after a call to HeaderAvailable().
const std::vector<int32>& NgramCounts() const { return ngram_counts_; }
@ -118,6 +136,8 @@ class ArpaFileParser {
ArpaParseOptions options_;
fst::SymbolTable* symbols_; // Not owned.
int32 line_number_;
uint32 warning_count_;
std::string current_line_;
std::vector<int32> ngram_counts_;
};

View file

@ -106,18 +106,21 @@ class OptimizedHistKey {
uint64 data_;
};
} // namespace
template <class HistKey>
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
public:
ArpaLmCompilerImpl(fst::StdVectorFst* fst, Symbol bos_symbol,
Symbol eos_symbol, Symbol sub_eps);
ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
Symbol sub_eps);
virtual void ConsumeNGram(const NGram& ngram, bool is_highest);
virtual void ConsumeNGram(const NGram &ngram, bool is_highest);
private:
StateId AddStateWithBackoff(HistKey key, float backoff);
void CreateBackoff(HistKey key, StateId state, float weight);
ArpaLmCompiler *parent_; // Not owned.
fst::StdVectorFst* fst_; // Not owned.
Symbol bos_symbol_;
Symbol eos_symbol_;
@ -131,10 +134,9 @@ class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
template <class HistKey>
ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
fst::StdVectorFst* fst, Symbol bos_symbol,
Symbol eos_symbol, Symbol sub_eps) : fst_(fst), bos_symbol_(bos_symbol),
eos_symbol_(eos_symbol),
sub_eps_(sub_eps) {
ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
: parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) {
// The algorithm maintains state per history. The 0-gram is a special state
// for empty history. All unigrams (including BOS) backoff into this state.
StateId zerogram = fst_->AddState();
@ -150,8 +152,8 @@ ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
}
template <class HistKey>
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(
const NGram& ngram, bool is_highest) {
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
bool is_highest) {
// Generally, we do the following. Suppose we are adding an n-gram "A B
// C". Then find the node for "A B", add a new node for "A B C", and connect
// them with the arc accepting "C" with the specified weight. Also, add a
@ -181,7 +183,9 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(
if (source_it == history_.end()) {
// There was no "A B", therefore the probability of "A B C" is zero.
// Print a warning and discard current n-gram.
KALDI_WARN << "No parent gram, skipping";
if (parent_->ShouldWarn())
KALDI_WARN << parent_->LineReference()
<< " skipped: no parent (n-1)-gram exists";
return;
}
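// Illustrative sketch of the above (hypothetical n-gram, not from the source):
// for a trigram "A B C" with parsed (natural-log) logprob p, source is the
// state for the history "A B"; dest is the state for "A B C" (or for the
// backoff history "B C" when trigrams are the highest order); the arc added
// is then fst_->AddArc(source, fst::StdArc(C, C, -p, dest)).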
@ -225,6 +229,7 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(
// Add arc from source to dest, whichever way it was found.
fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
return;
}
// Find or create a new state for n-gram defined by key, and ensure it has a
@ -266,8 +271,6 @@ inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(
fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
}
} // namespace
ArpaLmCompiler::~ArpaLmCompiler() {
if (impl_ != NULL)
delete impl_;
@ -286,25 +289,24 @@ void ArpaLmCompiler::HeaderAvailable() {
max_symbol += NgramCounts()[0];
if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(
&fst_, Options().bos_symbol, Options().eos_symbol, sub_eps_);
impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
} else {
impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(
&fst_, Options().bos_symbol, Options().eos_symbol, sub_eps_);
impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
KALDI_LOG << "Reverting to slower state tracking because model is large: "
<< NgramCounts().size() << "-gram with symbols up to "
<< max_symbol;
}
}
void ArpaLmCompiler::ConsumeNGram(const NGram& ngram) {
void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
// <s> is invalid in tails, </s> in heads of an n-gram.
for (int i = 0; i < ngram.words.size(); ++i) {
if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
(i + 1 < ngram.words.size()
&& ngram.words[i] == Options().eos_symbol)) {
KALDI_WARN << "In line " << LineNumber()
<< ": Skipping n-gram with invalid BOS/EOS placement";
if (ShouldWarn())
KALDI_WARN << LineReference()
<< " skipped: n-gram has invalid BOS/EOS placement";
return;
}
}

View file

@ -39,6 +39,7 @@ class ArpaLmCompiler : public ArpaFileParser {
~ArpaLmCompiler();
const fst::StdVectorFst& Fst() const { return fst_; }
fst::StdVectorFst* MutableFst() { return &fst_; }
protected:
// ArpaFileParser overrides.
@ -50,6 +51,7 @@ class ArpaLmCompiler : public ArpaFileParser {
int sub_eps_;
ArpaLmCompilerImplInterface* impl_; // Owned.
fst::StdVectorFst fst_;
template <class HistKey> friend class ArpaLmCompilerImpl;
};
} // namespace kaldi

View file

@ -268,7 +268,7 @@ void ConstArpaLmBuilder::HeaderAvailable() {
ngram_order_ = NgramCounts().size();
}
void ConstArpaLmBuilder::ConsumeNGram(const NGram& ngram) {
void ConstArpaLmBuilder::ConsumeNGram(const NGram &ngram) {
int32 cur_order = ngram.words.size();
// If <ngram_order_> is larger than 1, then we do not create LmState for
// the final order entry. We only keep the log probability for it.
@ -1062,15 +1062,9 @@ bool ConstArpaLmDeterministicFst::GetArc(StateId s,
return true;
}
bool BuildConstArpaLm(const int32 bos_symbol, const int32 eos_symbol,
const int32 unk_symbol,
bool BuildConstArpaLm(const ArpaParseOptions& options,
const std::string& arpa_rxfilename,
const std::string& const_arpa_wxfilename) {
ArpaParseOptions options;
options.bos_symbol = bos_symbol;
options.eos_symbol = eos_symbol;
options.unk_symbol = unk_symbol;
ConstArpaLmBuilder lm_builder(options);
KALDI_LOG << "Reading " << arpa_rxfilename;
ReadKaldiObject(arpa_rxfilename, &lm_builder);

View file

@ -25,6 +25,7 @@
#include "base/kaldi-common.h"
#include "fstext/deterministic-fst.h"
#include "lm/arpa-file-parser.h"
#include "util/common-utils.h"
namespace kaldi {
@ -418,8 +419,7 @@ class ConstArpaLmDeterministicFst
// Reads in an Arpa format language model and converts it into ConstArpaLm
// format. We assume that the words in the input Arpa format language model have
// been converted into integers.
bool BuildConstArpaLm(const int32 bos_symbol, const int32 eos_symbol,
const int32 unk_symbol,
bool BuildConstArpaLm(const ArpaParseOptions& options,
const std::string& arpa_rxfilename,
const std::string& const_arpa_wxfilename);

View file

@ -13,7 +13,7 @@
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
@ -33,7 +33,7 @@ int main(int argc, char *argv[]) {
"that wants to rescore lattices. We assume that the words in the\n"
"input arpa language model has been converted to integers.\n"
"\n"
"The program is used joinly with utils/map_arpa_lm.pl to build\n"
"The program is used jointly with utils/map_arpa_lm.pl to build\n"
"ConstArpaLm format language model. We first map the words in an Arpa\n"
"format language model to integers using utils/map_arpa_m.pl, and\n"
"then use this program to build a ConstArpaLm format language model.\n"
@ -44,16 +44,19 @@ int main(int argc, char *argv[]) {
kaldi::ParseOptions po(usage);
int32 unk_symbol = -1;
int32 bos_symbol = -1;
int32 eos_symbol = -1;
po.Register("unk-symbol", &unk_symbol,
ArpaParseOptions options;
options.Register(&po);
// Ideally, these registrations would be in ArpaParseOptions, but some
// programs want integers and others want symbols, so we register them
// outside instead.
po.Register("unk-symbol", &options.unk_symbol,
"Integer corresponds to unknown-word in language model. -1 if "
"no such word is provided.");
po.Register("bos-symbol", &bos_symbol,
po.Register("bos-symbol", &options.bos_symbol,
"Integer corresponds to <s>. You must set this to your actual "
"BOS integer.");
po.Register("eos-symbol", &eos_symbol,
po.Register("eos-symbol", &options.eos_symbol,
"Integer corresponds to </s>. You must set this to your actual "
"EOS integer.");
@ -64,7 +67,7 @@ int main(int argc, char *argv[]) {
exit(1);
}
if (bos_symbol == -1 || eos_symbol == -1) {
if (options.bos_symbol == -1 || options.eos_symbol == -1) {
KALDI_ERR << "Please set --bos-symbol and --eos-symbol.";
exit(1);
}
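For reference, a hypothetical invocation after this change (the flags are the ones registered above; the integer ids are illustrative, and the binary name arpa-to-const-arpa is assumed from this file):
arpa-to-const-arpa --bos-symbol=1 --eos-symbol=2 --unk-symbol=3 lm/mapped.arpa lm/G.carpa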
@ -72,8 +75,8 @@ int main(int argc, char *argv[]) {
std::string arpa_rxfilename = po.GetArg(1),
const_arpa_wxfilename = po.GetOptArg(2);
bool ans = BuildConstArpaLm(bos_symbol, eos_symbol, unk_symbol,
arpa_rxfilename, const_arpa_wxfilename);
bool ans = BuildConstArpaLm(options, arpa_rxfilename,
const_arpa_wxfilename);
if (ans)
return 0;
else

View file

@ -34,6 +34,9 @@ int main(int argc, char *argv[]) {
"data/lang/words.txt lm/input.arpa G.fst\n";
ParseOptions po(usage);
ArpaParseOptions options;
options.Register(&po);
// Option flags.
std::string bos_symbol = "<s>";
std::string eos_symbol = "</s>";
@ -41,6 +44,7 @@ int main(int argc, char *argv[]) {
std::string read_syms_filename;
std::string write_syms_filename;
bool keep_symbols = false;
bool ilabel_sort = true;
po.Register("bos-symbol", &bos_symbol,
"Beginning of sentence symbol");
@ -56,6 +60,8 @@ int main(int argc, char *argv[]) {
po.Register("keep-symbols", &keep_symbols,
"Store symbol table with FST. Forced true if "
"symbol tables are neiter read or written");
po.Register("ilabel-sort", &ilabel_sort,
"Ilabel-sort the output FST");
po.Read(argc, argv);
@ -66,7 +72,6 @@ int main(int argc, char *argv[]) {
std::string arpa_rxfilename = po.GetArg(1),
fst_wxfilename = po.GetOptArg(2);
ArpaParseOptions options;
int64 disambig_symbol_id = 0;
fst::SymbolTable* symbols;
@ -110,6 +115,11 @@ int main(int argc, char *argv[]) {
ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
ReadKaldiObject(arpa_rxfilename, &lm_compiler);
// Sort the FST in-place if requested by options.
if (ilabel_sort) {
fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
}
// Write symbols if requested.
if (!write_syms_filename.empty()) {
kaldi::Output kosym(write_syms_filename, false);
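For reference, a sketch of the invocation style the recipes in this commit switch to (paths illustrative):
gunzip -c lm/lm.arpa.gz | arpa2fst --disambig-symbol=#0 --read-symbol-table=data/lang/words.txt - data/lang/G.fst
With the new --ilabel-sort option left at its default (true), the resulting G.fst comes out arc-sorted on input labels.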

View file

@ -220,8 +220,9 @@ inline void Component::Propagate(const CuMatrixBase<BaseFloat> &in,
CuMatrix<BaseFloat> *out) {
// Check the dims
if (input_dim_ != in.NumCols()) {
KALDI_ERR << "Non-matching dims! " << TypeToMarker(GetType())
<< " input-dim : " << input_dim_ << " data : " << in.NumCols();
KALDI_ERR << "Non-matching dims on the input of " << TypeToMarker(GetType())
<< " component. The input-dim is " << input_dim_
<< ", the data had " << in.NumCols() << " dims.";
}
// Allocate target buffer
out->Resize(in.NumRows(), output_dim_, kSetZero); // reset

View file

@ -41,7 +41,7 @@ namespace nnet1 {
class SimpleSentenceAveragingComponent : public Component {
public:
SimpleSentenceAveragingComponent(int32 dim_in, int32 dim_out)
: Component(dim_in, dim_out), gradient_boost_(100.0)
: Component(dim_in, dim_out), gradient_boost_(100.0), shrinkage_(0.0), only_summing_(false)
{ }
~SimpleSentenceAveragingComponent()
{ }
@ -56,7 +56,9 @@ class SimpleSentenceAveragingComponent : public Component {
while (!is.eof()) {
ReadToken(is, false, &token);
if (token == "<GradientBoost>") ReadBasicType(is, false, &gradient_boost_);
else KALDI_ERR << "Unknown token " << token << ", a typo in config? (GradientBoost)";
else if (token == "<Shrinkage>") ReadBasicType(is, false, &shrinkage_);
else if (token == "<OnlySumming>") ReadBasicType(is, false, &only_summing_);
else KALDI_ERR << "Unknown token " << token << ", a typo in config? (GradientBoost|Shrinkage|OnlySumming)";
is >> std::ws; // eat-up whitespace
}
}
@ -64,15 +66,29 @@ class SimpleSentenceAveragingComponent : public Component {
void ReadData(std::istream &is, bool binary) {
ExpectToken(is, binary, "<GradientBoost>");
ReadBasicType(is, binary, &gradient_boost_);
if(PeekToken(is, binary) == 'S') {
ExpectToken(is, binary, "<Shrinkage>");
ReadBasicType(is, binary, &shrinkage_);
}
if(PeekToken(is, binary) == 'O') {
ExpectToken(is, binary, "<OnlySumming>");
ReadBasicType(is, binary, &only_summing_);
}
}
void WriteData(std::ostream &os, bool binary) const {
WriteToken(os, binary, "<GradientBoost>");
WriteBasicType(os, binary, gradient_boost_);
WriteToken(os, binary, "<Shrinkage>");
WriteBasicType(os, binary, shrinkage_);
WriteToken(os, binary, "<OnlySumming>");
WriteBasicType(os, binary, only_summing_);
}
std::string Info() const {
return std::string("\n gradient-boost ") + ToString(gradient_boost_);
return std::string("\n gradient-boost ") + ToString(gradient_boost_) +
", shrinkage: " + ToString(shrinkage_) +
", only summing: " + ToString(only_summing_);
}
std::string InfoGradient() const {
return Info();
@ -81,7 +97,11 @@ class SimpleSentenceAveragingComponent : public Component {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// get the average row-vector,
average_row_.Resize(InputDim());
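// The shrinkage term acts like shrinkage_ extra all-zero rows: the sum is
// scaled by 1/(N + shrinkage_), e.g. 4 rows with shrinkage_ = 2 scale the
// sum by 1/6, which damps the average for short sentences.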
average_row_.AddRowSumMat(1.0/in.NumRows(), in, 0.0);
if (only_summing_) {
average_row_.AddRowSumMat(1.0, in, 0.0);
} else {
average_row_.AddRowSumMat(1.0/(in.NumRows()+shrinkage_), in, 0.0);
}
// copy it on the output,
out->AddVecToRows(1.0, average_row_, 0.0);
}
@ -97,7 +117,11 @@ class SimpleSentenceAveragingComponent : public Component {
//
// getting the average output diff,
average_diff_.Resize(OutputDim());
average_diff_.AddRowSumMat(1.0/out_diff.NumRows(), out_diff, 0.0);
if (only_summing_) {
average_diff_.AddRowSumMat(1.0, out_diff, 0.0);
} else {
average_diff_.AddRowSumMat(1.0/(out_diff.NumRows()+shrinkage_), out_diff, 0.0);
}
// copy the derivative into the input diff, (applying gradient-boost!!)
in_diff->AddVecToRows(gradient_boost_, average_diff_, 0.0);
}
@ -106,6 +130,9 @@ class SimpleSentenceAveragingComponent : public Component {
CuVector<BaseFloat> average_row_; ///< auxiliary buffer for forward propagation,
CuVector<BaseFloat> average_diff_; ///< auxiliary buffer for backpropagation,
BaseFloat gradient_boost_; ///< increase of gradient applied in backpropagation,
BaseFloat shrinkage_; ///< Number of 'imaginary' zero-vectors in the average
///< (shrinks the average vector for shorter sentences),
bool only_summing_; ///< Removes normalization term from arithmetic mean (when true).
};

View file

@ -515,7 +515,7 @@ void CachingOptimizingCompiler::ReadCache(std::istream &is, bool binary) {
bool read_cache = (opt_config_ == opt_config_cached);
if (read_cache) {
size_t computation_cache_size;
int32 computation_cache_size;
ExpectToken(is, binary, "<ComputationCacheSize>");
ReadBasicType(is, binary, &computation_cache_size);
KALDI_ASSERT(computation_cache_size >= 0);