Mirror of https://github.com/mozilla/kaldi.git
Merge branch 'master' into chain
This commit is contained in:
Commit ce708ea167
@@ -1,4 +1,4 @@
#!/bin/bash
#

if [ -f path.sh ]; then . path.sh; fi

@@ -15,25 +15,12 @@ arpa_lm=$1

cp -r data/lang data/lang_test

# grep -v '<s> <s>' etc. is only for future-proofing this script. Our
# LM doesn't have these "invalid combinations". These can cause
# determinization failures of CLG [ends up being epsilon cycles].
# Note: remove_oovs.pl takes a list of words in the LM that aren't in
# our word list. Since our LM doesn't have any, we just give it
# /dev/null [we leave it in the script to show how you'd do it].
gunzip -c "$arpa_lm" | \
  grep -v '<s> <s>' | \
  grep -v '</s> <s>' | \
  grep -v '</s> </s>' | \
  arpa2fst - | fstprint | \
  utils/remove_oovs.pl /dev/null | \
  utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=data/lang_test/words.txt \
    --osymbols=data/lang_test/words.txt --keep_isymbols=false --keep_osymbols=false | \
  fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang_test/G.fst
fstisstochastic data/lang_test/G.fst
arpa2fst --disambig-symbol=#0 \
  --read-symbol-table=data/lang_test/words.txt - data/lang_test/G.fst

echo "Checking how stochastic G is (the first of these numbers should be small):"
fstisstochastic data/lang_test/G.fst

## Check lexicon.
## just have a look and make sure it seems sane.

@@ -61,4 +48,3 @@ fsttablecompose data/lang/L_disambig.fst data/lang_test/G.fst | \
fstisstochastic || echo LG is not stochastic

echo AMI_format_data succeeded.

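The change above is the pattern this whole commit follows: the old grep/fstprint/fstcompile pipeline is replaced by a direct arpa2fst call. A minimal, self-contained sketch of the new-style G.fst build (paths are illustrative; words.txt must already contain the #0 disambiguation symbol):

#!/bin/bash
# Sketch of the new G.fst build used throughout this commit: arpa2fst now
# consumes the ARPA LM directly, maps backoff epsilons to the #0
# disambiguation symbol, and writes a compiled FST.
arpa_lm=lm.arpa.gz            # hypothetical gzipped ARPA LM
lang=data/lang_test           # hypothetical lang directory
gunzip -c "$arpa_lm" | \
  arpa2fst --disambig-symbol='#0' \
           --read-symbol-table=$lang/words.txt - $lang/G.fst
fstisstochastic $lang/G.fst   # the first number printed should be near zero
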
@@ -0,0 +1,37 @@
Example scripts showing how to use a pre-trained English chain model and the Kaldi base code to recognize any number of wav files.

IMPORTANT: wav files must be in 16kHz, 16 bit little-endian format.

Model:
The pre-trained English model was released by Api.ai under the Creative Commons Attribution-ShareAlike 4.0 International Public License.
- Acoustic data is mostly mobile-recorded data
- The language model is based on Assistant.ai logs and is good for understanding short commands, like "Wake me up at 7 am"
For more details, visit https://github.com/api-ai/api-ai-english-asr-model

Usage:
- Ensure Kaldi is compiled and these scripts are inside a kaldi/egs/<subfolder>/ directory
- Run ./download-model.sh to download the pre-trained chain model
- Run ./recognize-wav.sh test1.wav test2.wav to do recognition
- See the output for recognition results

Using the steps/nnet3/decode.sh script:
You can use Kaldi's steps/nnet3/decode.sh, which will decode data and calculate the Word Error Rate (WER) for it.
Steps:
- Run recognize-wav.sh test1.wav test2.wav; it will make the data dir, calculate MFCC features for it, and do decoding (you need only the first two steps out of it)
- If you want WER, edit data/test-corpus/text and replace NO_TRANSCRIPTION with the expected transcription for every wav file
- Run steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/
- See the exp/api.ai-model/decode/wer* files for WER and the exp/api.ai-model/decode/log/ files for decoding output

Online Decoder:
At the moment Kaldi does not support online decoding for nnet3 models, but decoders can be found at https://github.com/api-ai/kaldi/ .
See http://kaldi.sourceforge.net/online_decoding.html for more information about Kaldi online decoding.
Steps:
- Run ./local/create-corpus.sh data/test-corpus/ test1.wav test2.wav (or just run recognize-wav.sh) to create the corpus
- If you want WER, edit data/test-corpus/text and replace NO_TRANSCRIPTION with the expected transcription for every wav file
- Make a config file exp/api.ai-model/online.conf with the following content:
==CONTENT START==
--feature-type=mfcc
--mfcc-config=exp/api.ai-model/mfcc.conf
==CONTENT END==
- Run steps/online/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/
- See the exp/api.ai-model/decode/wer* files for WER and the exp/api.ai-model/decode/log/ files for decoding output

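Put together, the offline workflow the README describes is just the following (a hedged sketch; it assumes the scripts sit in kaldi/egs/<subfolder>/ and the wav files are 16kHz, 16-bit little-endian):

#!/bin/bash
# End-to-end sketch of the offline recipe described in the README above.
./download-model.sh                       # fetch exp/api.ai-model/
./recognize-wav.sh test1.wav test2.wav    # data prep, MFCC, decoding
# Optional WER scoring after filling in data/test-corpus/text:
steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --cmd run.pl --nj 1 \
  exp/api.ai-model/ data/test-corpus/ exp/api.ai-model/decode/
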
@@ -0,0 +1,24 @@
#!/bin/bash
# Downloads the Api.ai chain model into exp/api.ai-model/ (replaces an existing one)

DOWNLOAD_URL="https://api.ai/downloads/api.ai-kaldi-asr-model.zip"

echo "Downloading model"
wget -N $DOWNLOAD_URL || { echo "Unable to download model: $DOWNLOAD_URL"; exit 1; }

echo "Unpacking model"
unzip api.ai-kaldi-asr-model.zip || { echo "Unable to extract api.ai-kaldi-asr-model.zip"; exit 1; }

echo "Moving model to exp/api.ai-model/"
if [ ! -d exp ]; then
  mkdir exp;
fi;

if [ -d exp/api.ai-model ]; then
  echo "Found existing model, removing";
  rm -rf exp/api.ai-model/
fi

mv api.ai-kaldi-asr-model exp/api.ai-model || { echo "Unable to move model to exp/"; exit 1; }

echo "Model is ready; use recognize-wav.sh to do voice recognition"

@@ -0,0 +1,50 @@
#!/bin/bash

# Checking arguments
if [ $# -le 1 ]; then
  echo "Use $0 <datadir> test1.wav [test2.wav] ..."
  echo "  $0 data/test-corpus test1.wav test2.wav"
  exit 0;
fi

CORPUS=$1
shift
for file in "$@"; do
  if [[ "$file" != *.wav ]]; then
    echo "Expecting .wav files, got $file"
    exit 1;
  fi

  if [ ! -f "$file" ]; then
    echo "$file not found";
    exit 1;
  fi
done;


echo "Initializing $CORPUS"
if [ ! -d "$CORPUS" ]; then
  echo "Creating $CORPUS directory"
  mkdir -p "$CORPUS" || { echo "Unable to create data dir"; exit 1; }
fi;

wav_scp="$CORPUS/wav.scp"
spk2utt="$CORPUS/spk2utt"
utt2spk="$CORPUS/utt2spk"
text="$CORPUS/text"

# Truncate the output files.
cat </dev/null >$wav_scp
cat </dev/null >$spk2utt
cat </dev/null >$utt2spk
cat </dev/null >$text
rm $CORPUS/feats.scp 2>/dev/null;
rm $CORPUS/cmvn.scp 2>/dev/null;

for file in "$@"; do
  id=$(echo "$file" | sed -e 's/ /_/g')
  echo "$id $file" >>$wav_scp
  echo "$id $id" >>$spk2utt
  echo "$id $id" >>$utt2spk
  echo "$id NO_TRANSCRIPTION" >>$text
done;

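For a single input such as test1.wav, the loop above writes one line per file into each of the four standard Kaldi data-dir files; the sanitized file name serves as both utterance ID and speaker ID. An illustration, assuming the default data/test-corpus directory:

# Contents produced for test1.wav (illustration only):
#   data/test-corpus/wav.scp  ->  test1.wav test1.wav        (utt-id, wav path)
#   data/test-corpus/utt2spk  ->  test1.wav test1.wav        (utt-id, spk-id)
#   data/test-corpus/spk2utt  ->  test1.wav test1.wav        (spk-id, utt-ids)
#   data/test-corpus/text     ->  test1.wav NO_TRANSCRIPTION (placeholder transcript)
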
@@ -0,0 +1 @@
../../wsj/s5/local/score.sh

@@ -0,0 +1,6 @@
export KALDI_ROOT=`pwd`/../../..
[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
[ ! -f $KALDI_ROOT/src/path.sh ] && echo >&2 "The standard file $KALDI_ROOT/src/path.sh is not present -> Exit!" && exit 1
. $KALDI_ROOT/src/path.sh
export LC_ALL=C

@@ -0,0 +1,53 @@
#!/bin/bash
# Copyright 2016 Api.ai (Author: Ilya Platonov)
# Apache 2.0

# This script demonstrates Kaldi decoding using a pre-trained model. It will decode a list of wav files.
#
# IMPORTANT: wav files must be in 16kHz, 16 bit little-endian format.
#
# This script tries to follow what other scripts do in terms of directory structure and data handling.
#
# Use the ./download-model.sh script to download the ASR model.
# See https://github.com/api-ai/api-ai-english-asr-model for details about the model and how to use it.

. path.sh
MODEL_DIR="exp/api.ai-model"
DATA_DIR="data/test-corpus"

echo "///////"
echo "// IMPORTANT: wav files must be in 16kHz, 16 bit little-endian format."
echo "//////";

for file in final.mdl HCLG.fst words.txt frame_subsampling_factor; do
  if [ ! -f $MODEL_DIR/$file ]; then
    echo "$MODEL_DIR/$file not found, use ./download-model.sh"
    exit 1;
  fi
done;

for app in nnet3-latgen-faster apply-cmvn lattice-scale; do
  command -v $app >/dev/null 2>&1 || { echo >&2 "$app not found, is kaldi compiled?"; exit 1; }
done;

local/create-corpus.sh $DATA_DIR "$@" || exit 1;

echo "///////"
echo "// Computing mfcc and cmvn (cmvn is not really used)"
echo "//////";

steps/make_mfcc.sh --nj 1 --mfcc-config $MODEL_DIR/mfcc.conf \
  --cmd "run.pl" $DATA_DIR exp/make_mfcc exp/mfcc || { echo "Unable to calculate mfcc; ensure 16kHz, 16 bit little-endian wav format, or see the log"; exit 1; };
steps/compute_cmvn_stats.sh $DATA_DIR exp/make_mfcc/ exp/mfcc || exit 1;

echo "///////"
echo "// Doing decoding (see log for results)"
echo "//////";
frame_subsampling_factor=$(cat $MODEL_DIR/frame_subsampling_factor)
nnet3-latgen-faster --frame-subsampling-factor=$frame_subsampling_factor --frames-per-chunk=50 --extra-left-context=0 \
  --extra-right-context=0 --extra-left-context-initial=-1 --extra-right-context-final=-1 \
  --minimize=false --max-active=7000 --min-active=200 --beam=15.0 --lattice-beam=8.0 \
  --acoustic-scale=1.0 --allow-partial=true \
  --word-symbol-table=$MODEL_DIR/words.txt $MODEL_DIR/final.mdl $MODEL_DIR/HCLG.fst \
  "ark,s,cs:apply-cmvn --norm-means=false --norm-vars=false --utt2spk=ark:$DATA_DIR/utt2spk scp:$DATA_DIR/cmvn.scp scp:$DATA_DIR/feats.scp ark:- |" \
  "ark:|lattice-scale --acoustic-scale=10.0 ark:- ark:- >exp/lat.1"

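The transcripts end up in the decoder log, but they can also be pulled out of the lattice the script writes to exp/lat.1. A hedged sketch using standard Kaldi tools (the symbol table is the one shipped with the model):

# Extract the best path from the lattice written above and map the integer
# word IDs back to words; prints one "utt-id WORD WORD ..." line per file.
lattice-best-path ark:exp/lat.1 ark,t:- 2>/dev/null | \
  utils/int2sym.pl -f 2- exp/api.ai-model/words.txt
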
@@ -0,0 +1 @@
../../wsj/s5/steps/

@@ -0,0 +1 @@
../../wsj/s5/utils/

@@ -18,40 +18,23 @@ fi

lm_dir=$1

tmpdir=data/local/lm_tmp
lexicon=data/local/lang_tmp/lexiconp.txt
mkdir -p $tmpdir

# This loop was taken verbatim from wsj_format_data.sh, and I'm leaving it in place in
# case we decide to add more language models at some point
for lm_suffix in tgpr; do
  test=data/lang_test_${lm_suffix}
  mkdir -p $test
  for f in phones.txt words.txt phones.txt L.fst L_disambig.fst phones oov.txt oov.int; do
  for f in phones.txt words.txt phones.txt L.fst L_disambig.fst phones topo oov.txt oov.int; do
    cp -r data/lang/$f $test
  done
  gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz |\
    utils/find_arpa_oovs.pl $test/words.txt > $tmpdir/oovs_${lm_suffix}.txt || exit 1

  # grep -v '<s> <s>' because the LM seems to have some strange and useless
  # stuff in it with multiple <s>'s in the history. Encountered some other similar
  # things in a LM from Geoff. Removing all "illegal" combinations of <s> and </s>,
  # which are supposed to occur only at the begin/end of an utt. These can cause
  # determinization failures of CLG [ends up being epsilon cycles].
  gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz | \
    grep -v '<s> <s>' | \
    grep -v '</s> <s>' | \
    grep -v '</s> </s>' | \
    arpa2fst - | fstprint | \
    utils/remove_oovs.pl $tmpdir/oovs_${lm_suffix}.txt | \
    utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$test/words.txt \
      --osymbols=$test/words.txt --keep_isymbols=false --keep_osymbols=false | \
    fstrmepsilon | fstarcsort --sort_type=ilabel > $test/G.fst
  arpa2fst --disambig-symbol=#0 \
    --read-symbol-table=$test/words.txt - $test/G.fst

  utils/validate_lang.pl $test || exit 1;
done

echo "Succeeded in formatting data."
rm -r $tmpdir

exit 0

@@ -49,24 +49,9 @@ for lm_suffix in tgsmall tgmed; do
  test=${src_dir}_test_${lm_suffix}
  mkdir -p $test
  cp -r ${src_dir}/* $test
  gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz |\
    utils/find_arpa_oovs.pl $test/words.txt > $tmpdir/oovs_${lm_suffix}.txt || exit 1

  # grep -v '<s> <s>' because the LM seems to have some strange and useless
  # stuff in it with multiple <s>'s in the history. Encountered some other
  # similar things in a LM from Geoff. Removing all "illegal" combinations of
  # <s> and </s>, which are supposed to occur only at the begin/end of an utt. These
  # can cause determinization failures of CLG [ends up being epsilon cycles].
  gunzip -c $lm_dir/lm_${lm_suffix}.arpa.gz | \
    grep -v '<s> <s>' | \
    grep -v '</s> <s>' | \
    grep -v '</s> </s>' | \
    arpa2fst - | fstprint | \
    utils/remove_oovs.pl $tmpdir/oovs_${lm_suffix}.txt | \
    utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$test/words.txt \
      --osymbols=$test/words.txt --keep_isymbols=false --keep_osymbols=false | \
    fstrmepsilon | fstarcsort --sort_type=ilabel > $test/G.fst

  arpa2fst --disambig-symbol=#0 \
    --read-symbol-table=$test/words.txt - $test/G.fst
  utils/validate_lang.pl --skip-determinization-check $test || exit 1;
done

@@ -40,10 +40,10 @@ open(DUR3, ">$dir/3sec") || die "Failed opening output file $dir/3sec";
open(DUR10, ">$dir/10sec") || die "Failed opening output file $dir/10sec";
open(DUR30, ">$dir/30sec") || die "Failed opening output file $dir/30sec";

my $key_str = `wget -qO- "http://www.itl.nist.gov/iad/mig/tests/lang/2007/lid07key_v5.txt"`;
my $key_str = `wget -qO- "http://www.openslr.org/resources/23/lre07_key.txt"`;
@key_lines = split("\n",$key_str);
%utt2lang = ();
%utt2dur = ();
foreach (@key_lines) {
  @words = split(' ', $_);
  if (index($words[0], "#") == -1) {

@@ -27,7 +27,7 @@ tmpdir=data/local/lm_tmp
lexicon=data/local/lang${lang_suffix}_tmp/lexiconp.txt
mkdir -p $tmpdir

for x in train_si284 test_eval92 test_eval93 test_dev93 test_eval92_5k test_eval93_5k test_dev93_5k dev_dt_05 dev_dt_20; do
  mkdir -p data/$x
  cp $srcdir/${x}_wav.scp data/$x/wav.scp || exit 1;
  cp $srcdir/$x.txt data/$x/text || exit 1;

@@ -49,22 +49,8 @@ for lm_suffix in bg tgpr tg bg_5k tgpr_5k tg_5k; do
  cp -r data/lang${lang_suffix}/* $test || exit 1;

  gunzip -c $lmdir/lm_${lm_suffix}.arpa.gz | \
    utils/find_arpa_oovs.pl $test/words.txt > $tmpdir/oovs_${lm_suffix}.txt

  # grep -v '<s> <s>' because the LM seems to have some strange and useless
  # stuff in it with multiple <s>'s in the history. Encountered some other similar
  # things in a LM from Geoff. Removing all "illegal" combinations of <s> and </s>,
  # which are supposed to occur only at the begin/end of an utt. These can cause
  # determinization failures of CLG [ends up being epsilon cycles].
  gunzip -c $lmdir/lm_${lm_suffix}.arpa.gz | \
    grep -v '<s> <s>' | \
    grep -v '</s> <s>' | \
    grep -v '</s> </s>' | \
    arpa2fst - | fstprint | \
    utils/remove_oovs.pl $tmpdir/oovs_${lm_suffix}.txt | \
    utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$test/words.txt \
      --osymbols=$test/words.txt --keep_isymbols=false --keep_osymbols=false | \
    fstrmepsilon | fstarcsort --sort_type=ilabel > $test/G.fst
  arpa2fst --disambig-symbol=#0 \
    --read-symbol-table=$test/words.txt - $test/G.fst

  utils/validate_lang.pl --skip-determinization-check $test || exit 1;
done

@@ -45,17 +45,13 @@ fi
# Be careful: this time we dispense with the grep -v '<s> <s>' so this might
# not work for LMs generated from all toolkits.
gunzip -c $lm_srcdir_3g/lm_pr6.0.gz | \
  arpa2fst - | fstprint | \
  utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
    --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
  fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_tgpr/G.fst || exit 1;
  arpa2fst --disambig-symbol=#0 \
    --read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_tgpr/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_tgpr/G.fst

gunzip -c $lm_srcdir_3g/lm_unpruned.gz | \
  arpa2fst - | fstprint | \
  utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
    --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
  fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_tg/G.fst || exit 1;
  arpa2fst --disambig-symbol=#0 \
    --read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_tg/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_tg/G.fst

# Build ConstArpaLm for the unpruned language model.

@@ -65,10 +61,8 @@ gunzip -c $lm_srcdir_3g/lm_unpruned.gz | \
  --unk-symbol=$unk - data/lang${lang_suffix}_test_bd_tgconst/G.carpa || exit 1

gunzip -c $lm_srcdir_4g/lm_unpruned.gz | \
  arpa2fst - | fstprint | \
  utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
    --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
  fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_fg/G.fst || exit 1;
  arpa2fst --disambig-symbol=#0 \
    --read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_fg/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_fg/G.fst

# Build ConstArpaLm for the unpruned language model.

@@ -78,10 +72,8 @@ gunzip -c $lm_srcdir_4g/lm_unpruned.gz | \
  --unk-symbol=$unk - data/lang${lang_suffix}_test_bd_fgconst/G.carpa || exit 1

gunzip -c $lm_srcdir_4g/lm_pr7.0.gz | \
  arpa2fst - | fstprint | \
  utils/eps2disambig.pl | utils/s2eps.pl | fstcompile --isymbols=$lang/words.txt \
    --osymbols=$lang/words.txt --keep_isymbols=false --keep_osymbols=false | \
  fstrmepsilon | fstarcsort --sort_type=ilabel > data/lang${lang_suffix}_test_bd_fgpr/G.fst || exit 1;
  arpa2fst --disambig-symbol=#0 \
    --read-symbol-table=$lang/words.txt - data/lang${lang_suffix}_test_bd_fgpr/G.fst || exit 1;
fstisstochastic data/lang${lang_suffix}_test_bd_fgpr/G.fst

exit 0;

@@ -42,11 +42,11 @@ Allowed options:
                     (default = "<***>")
  --wer-cutoff     : Ignore segments with WER higher than the specified value.
                     -1 means no segment will be ignored. (default = -1)
  --use-silence-midpoints : Set to 1 if you want to use silence midpoints
                     instead of min_sil_length for silence overhang. (default 0)
  --force-correct-boundary-words : Set to zero if the segments will not be
                     required to have boundary words to be correct. Default 1
  --aligned-ctm-filename : If set, the intermediate aligned ctm
                     is saved to this file
EOU

@@ -56,7 +56,7 @@ my $min_sil_length = 0.5;
my $separator = ";";
my $special_symbol = "<***>";
my $wer_cutoff = -1;
my $use_silence_midpoints = 0;
my $force_correct_boundary_words = 1;
my $aligned_ctm_filename = "";
GetOptions(

@@ -122,13 +122,13 @@ sub PrintSegment {

  # Works out the surrounding silence.
  my $index = $seg_start_index - 1;
  while ($index >= 0 && $aligned_ctm->[$index]->[0] eq
         "<eps>" && $aligned_ctm->[$index]->[3] == 0) {
    $index -= 1;
  }
  my $left_of_segment_has_deletion = "false";
  $left_of_segment_has_deletion = "true"
    if ($index > 0 && $aligned_ctm->[$index-1]->[0] ne "<eps>"
        && $aligned_ctm->[$index-1]->[3] == 0);

  my $pad_start_sil = ($aligned_ctm->[$seg_start_index]->[1] -

@@ -141,11 +141,11 @@ sub PrintSegment {
  my $right_of_segment_has_deletion = "false";
  $index = $seg_end_index + 1;
  while ($index < scalar(@{$aligned_ctm}) &&
         $aligned_ctm->[$index]->[0] eq "<eps>" &&
         $aligned_ctm->[$index]->[3] == 0) {
    $index += 1;
  }
  $right_of_segment_has_deletion = "true"
    if ($index < scalar(@{$aligned_ctm})-1 && $aligned_ctm->[$index+1]->[0] ne
        "<eps>" && $aligned_ctm->[$index - 1]->[3] > 0);
  my $pad_end_sil = ($aligned_ctm->[$index - 1]->[1] +

@@ -155,7 +155,7 @@ sub PrintSegment {
  if (($right_of_segment_has_deletion eq "true") || !$use_silence_midpoints) {
    if ($pad_end_sil > $min_sil_length / 2.0) {
      $pad_end_sil = $min_sil_length / 2.0;
    }
  }

  my $seg_start = $aligned_ctm->[$seg_start_index]->[1] - $pad_start_sil;

@@ -270,6 +270,7 @@ sub ProcessWav {
      $current_ctm, $current_align, $SO, $TO, $ACT) = @_;

  my $wav_id = $current_ctm->[0]->[0];
  my $channel_id = $current_ctm->[0]->[1];
  defined($wav_id) || die "Error: empty wav section\n";

  # First, we have to align the ctm file to the Levenshtein alignment.

@@ -324,8 +325,8 @@ sub ProcessWav {

  # Save the aligned CTM if needed
  if (defined($ACT)) {
    for (my $i=0; $i<=$#aligned_ctm; $i++) {
      print $ACT "$aligned_ctm[$i][0] $aligned_ctm[$i][1] ";
    for (my $i = 0; $i <= $#aligned_ctm; $i++) {
      print $ACT "$wav_id $channel_id $aligned_ctm[$i][0] $aligned_ctm[$i][1] ";
      print $ACT "$aligned_ctm[$i][2] $aligned_ctm[$i][3]\n";
    }
  }

@@ -346,8 +347,8 @@ sub ProcessWav {
  # length, and if there are no alignment errors around it. We also make sure
  # that the segment contains actual words, instead of pure silence.
  if ($aligned_ctm[$x]->[0] eq "<eps>" &&
      $aligned_ctm[$x]->[2] >= $min_sil_length
      && (($force_correct_boundary_words && $lcorrect eq "true" &&
           $rcorrect eq "true") || !$force_correct_boundary_words)) {
    if ($current_seg_length <= $max_seg_length &&
        $current_seg_length >= $min_seg_length) {

@@ -379,7 +380,7 @@ sub ProcessWav {
  # 011 A 3.39 0.23 SELL
  # 011 A 3.62 0.18 OFF
  # 011 A 3.83 0.45 ASSETS
  #
  # Output ctm:
  # 011 A 3.39 0.23 SELL
  # 011 A 3.62 0.18 OFF

@@ -111,10 +111,8 @@ while read line; do
      if (invoc[$x]) { printf("%s ", $x); } else { printf("%s ", oov); } }
      printf("\n"); }' > $wdir/text
    ngram-count -text $wdir/text -order $ngram_order "$srilm_options" -lm - |\
      arpa2fst - | fstprint | utils/eps2disambig.pl | utils/s2eps.pl |\
      fstcompile --isymbols=$lang/words.txt --osymbols=$lang/words.txt \
        --keep_isymbols=false --keep_osymbols=false |\
      fstrmepsilon | fstarcsort --sort_type=ilabel > $wdir/G.fst || exit 1;
      arpa2fst --disambig-symbol=#0 \
        --read-symbol-table=$lang/words.txt - $wdir/G.fst || exit 1;
  fi
  fstisstochastic $wdir/G.fst || echo "$0: $uttid/G.fst not stochastic."

@@ -134,7 +132,7 @@ while read line; do

  make-h-transducer --disambig-syms-out=$wdir/disambig_tid.int \
    --transition-scale=$tscale $wdir/ilabels_${N}_${P} \
    $model_dir/tree $model_dir/final.mdl > $wdir/Ha.fst

  # Builds HCLGa.fst
  fsttablecompose $wdir/Ha.fst $wdir/CLG.fst | \

@@ -143,10 +141,10 @@ while read line; do
    fstminimizeencoded > $wdir/HCLGa.fst
  fstisstochastic $wdir/HCLGa.fst ||\
    echo "$0: $uttid/HCLGa.fst is not stochastic"

  add-self-loops --self-loop-scale=$loopscale --reorder=true \
    $model_dir/final.mdl < $wdir/HCLGa.fst > $wdir/HCLG.fst

  if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then
    fstisstochastic $wdir/HCLG.fst ||\
      echo "$0: $uttid/HCLG.fst is not stochastic."

@@ -338,7 +338,7 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
                   left_context, right_context, min_deriv_time,
                   momentum, max_param_change,
                   shuffle_buffer_size, num_chunk_per_minibatch,
                   run_opts):
                   cache_read_opt, run_opts):
    # We cannot easily use a single parallel SGE job to do the main training,
    # because the computation of which archive and which --frame option
    # to use for each job is a little complex, so we spawn each one separately.

@@ -353,9 +353,15 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
        # the other indexes from.
        archive_index = (k % num_archives) + 1  # work out the 1-based archive index.

        cache_write_opt = ""
        if job == 1:
            # an option for writing cache (storing pairs of nnet-computations and
            # computation-requests) during training.
            cache_write_opt = "--write-cache={dir}/cache.{iter}".format(dir=dir, iter=iter+1)

        process_handle = RunKaldiCommand("""
{command} {train_queue_opt} {dir}/log/train.{iter}.{job}.log \
        nnet3-train {parallel_train_opts} \
        nnet3-train {parallel_train_opts} {cache_read_opt} {cache_write_opt} \
        --print-interval=10 --momentum={momentum} \
        --max-param-change={max_param_change} \
        --optimization.min-deriv-time={min_deriv_time} "{raw_model}" \

@@ -365,6 +371,7 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
                   train_queue_opt = run_opts.train_queue_opt,
                   dir = dir, iter = iter, next_iter = iter + 1, job = job,
                   parallel_train_opts = run_opts.parallel_train_opts,
                   cache_read_opt = cache_read_opt, cache_write_opt = cache_write_opt,
                   momentum = momentum, max_param_change = max_param_change,
                   min_deriv_time = min_deriv_time,
                   raw_model = raw_model_string, context_opts = context_opts,

@@ -387,7 +394,6 @@ def TrainNewModels(dir, iter, num_jobs, num_archives_processed, num_archives,
        open('{0}/.error'.format(dir), 'w').close()
        raise Exception("There was an error during training iteration {0}".format(iter))


def TrainOneIteration(dir, iter, egs_dir,
                      num_jobs, num_archives_processed, num_archives,
                      learning_rate, shrinkage_value, num_chunk_per_minibatch,

@@ -404,17 +410,21 @@ def TrainOneIteration(dir, iter, egs_dir,
    if iter > 0:
        ComputeProgress(dir, iter, egs_dir, run_opts)

    # an option for reading the cached computations (pairs of nnet-computations
    # and computation-requests) written during the previous iteration.
    cache_read_opt = ""
    if iter > 0 and (iter <= (num_hidden_layers-1) * add_layers_period) and (iter % add_layers_period == 0):

        do_average = False  # if we've just mixed up, don't do averaging but take the
                            # best.
        cur_num_hidden_layers = 1 + iter / add_layers_period
        config_file = "{0}/configs/layer{1}.config".format(dir, cur_num_hidden_layers)
        raw_model_string = "nnet3-am-copy --raw=true --learning-rate={lr} {dir}/{iter}.mdl - | nnet3-init --srand={iter} - {config} - |".format(lr=learning_rate, dir=dir, iter=iter, config=config_file)
    else:
        do_average = True
        if iter == 0:
            do_average = False  # on iteration 0, pick the best, don't average.
        else:
            cache_read_opt = "--read-cache={dir}/cache.{iter}".format(dir=dir, iter=iter)
        raw_model_string = "nnet3-am-copy --raw=true --learning-rate={0} {1}/{2}.mdl - |".format(learning_rate, dir, iter)

    if do_average:

@@ -437,7 +447,7 @@ def TrainOneIteration(dir, iter, egs_dir,
                                      left_context, right_context, min_deriv_time,
                                      momentum, max_param_change,
                                      shuffle_buffer_size, cur_num_chunk_per_minibatch,
                                      run_opts)
                                      cache_read_opt, run_opts)
    [models_to_average, best_model] = GetSuccessfulModels(num_jobs, '{0}/log/train.{1}.%.log'.format(dir,iter))
    nnets_list = []
    for n in models_to_average:

@@ -477,6 +487,12 @@ nnet3-am-copy --scale={shrink} --set-raw-nnet=- {dir}/{iter}.mdl {dir}/{new_iter
        raise Exception("Could not find {0}, at the end of iteration {1}".format(new_model, iter))
    elif os.stat(new_model).st_size == 0:
        raise Exception("{0} has size 0. Something went wrong in iteration {1}".format(new_model, iter))
    try:
        if cache_read_opt:
            os.remove("{dir}/cache.{iter}".format(dir=dir, iter=iter))
    except OSError:
        raise Exception("Error while trying to delete the cache file")


# args is a Namespace with the required parameters
def Train(args, run_opts):

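The cache plumbing added above lets nnet3-train skip recompiling its computation graphs between iterations: job 1 of iteration N writes cache.N+1, every job of iteration N+1 reads it, and the file is deleted once consumed (the try/except above). A hedged sketch of the resulting flag usage, with an assumed experiment directory exp/nnet3/tdnn; the remaining arguments are exactly the ones in the nnet3-train command template above:

# iteration 10, job 1 (reads the current cache and writes the next one):
#   nnet3-train --read-cache=exp/nnet3/tdnn/cache.10 \
#               --write-cache=exp/nnet3/tdnn/cache.11 ...other options as above...
# iteration 10, jobs 2..N (read-only):
#   nnet3-train --read-cache=exp/nnet3/tdnn/cache.10 ...other options as above...
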
@@ -39,20 +39,9 @@ for f in phones.txt words.txt L.fst L_disambig.fst phones/; do
done

lm_base=$(basename $lm '.gz')
gunzip -c $lm | utils/find_arpa_oovs.pl $out_dir/words.txt \
  > $out_dir/oovs_${lm_base}.txt

# Removing all "illegal" combinations of <s> and </s>, which are supposed to
# occur only at the begin/end of an utt. These can cause determinization failures
# of CLG [ends up being epsilon cycles].
gunzip -c $lm \
  | egrep -v '<s> <s>|</s> <s>|</s> </s>' \
  | arpa2fst - | fstprint \
  | utils/remove_oovs.pl $out_dir/oovs_${lm_base}.txt \
  | utils/eps2disambig.pl | utils/s2eps.pl \
  | fstcompile --isymbols=$out_dir/words.txt --osymbols=$out_dir/words.txt \
    --keep_isymbols=false --keep_osymbols=false \
  | fstrmepsilon | fstarcsort --sort_type=ilabel > $out_dir/G.fst
arpa2fst --disambig-symbol=#0 \
  --read-symbol-table=$out_dir/words.txt - $out_dir/G.fst
set +e
fstisstochastic $out_dir/G.fst
set -e

@@ -66,7 +55,7 @@ set -e
# this might cause determinization failure of CLG.
# #0 is treated as an empty word.
mkdir -p $out_dir/tmpdir.g
awk '{if(NF==1){ printf("0 0 %s %s\n", $1,$1); }}
     END{print "0 0 #0 #0"; print "0";}' \
  < "$lexicon" > $out_dir/tmpdir.g/select_empty.fst.txt

@@ -44,7 +44,7 @@ struct NGramTestData {
  float backoff;
};

std::ostream& operator<<(std::ostream& os, const NGramTestData& data) {
std::ostream& operator<<(std::ostream &os, const NGramTestData &data) {
  std::ios::fmtflags saved_state(os.flags());
  os << std::fixed << std::setprecision(6);

@@ -62,7 +62,7 @@ template <class T>
struct CountedArray {
  template <size_t N>
  CountedArray(T(&array)[N]) : array(array), count(N) { }
  const T* array;
  const T *array;
  const size_t count;
};

@@ -73,7 +73,7 @@ inline CountedArray<T> MakeCountedArray(T(&array)[N]) {

class TestableArpaFileParser : public ArpaFileParser {
 public:
  TestableArpaFileParser(ArpaParseOptions options, fst::SymbolTable* symbols)
  TestableArpaFileParser(ArpaParseOptions options, fst::SymbolTable *symbols)
      : ArpaFileParser(options, symbols),
        header_available_(false),
        read_complete_(false),

@@ -120,9 +120,10 @@ void TestableArpaFileParser::ReadComplete() {
  read_complete_ = true;
}

//
bool CompareNgrams(const NGramTestData& actual,
                   const NGramTestData& expected) {
bool CompareNgrams(const NGramTestData &actual,
                   NGramTestData expected) {
  expected.logprob *= Log(10.0);
  expected.backoff *= Log(10.0);
  if (actual.line_number != expected.line_number
      || !std::equal(actual.words, actual.words + kMaxOrder,
                     expected.words)

@@ -164,33 +165,33 @@ ngram 2=2\n\
ngram 3=2\n\
\n\
\\1-grams:\n\
-5.234679 4 -3.3\n\
-3.456783 5\n\
0.0000000 1 -2.5\n\
-4.333333 2\n\
-5.2 4 -3.3\n\
-3.4 5\n\
0 1 -2.5\n\
-4.3 2\n\
\n\
\\2-grams:\n\
-1.45678 4 5 -3.23\n\
-1.30490 1 4 -4.2\n\
-1.4 4 5 -3.2\n\
-1.3 1 4 -4.2\n\
\n\
\\3-grams:\n\
-0.34958 1 4 5\n\
-0.23940 4 5 2\n\
-0.3 1 4 5\n\
-0.2 4 5 2\n\
\n\
\\end\\";

int32 expect_counts[] = { 4, 2, 2 };
NGramTestData expect_ngrams[] = {
  { 7, -12.05329, { 4, 0, 0 }, -7.598531 },
  { 8, -7.959537, { 5, 0, 0 }, 0.0 },
  { 9, 0.0, { 1, 0, 0 }, -5.756463 },
  { 10, -9.977868, { 2, 0, 0 }, 0.0 },
  { 7, -5.2, { 4, 0, 0 }, -3.3 },
  { 8, -3.4, { 5, 0, 0 }, 0.0 },
  { 9, 0.0, { 1, 0, 0 }, -2.5 },
  { 10, -4.3, { 2, 0, 0 }, 0.0 },

  { 13, -3.354360, { 4, 5, 0 }, -7.437350 },
  { 14, -3.004643, { 1, 4, 0 }, -9.670857 },
  { 13, -1.4, { 4, 5, 0 }, -3.2 },
  { 14, -1.3, { 1, 4, 0 }, -4.2 },

  { 17, -0.804938, { 1, 4, 5 }, 0.0 },
  { 18, -0.551239, { 4, 5, 2 }, 0.0 } };
  { 17, -0.3, { 1, 4, 5 }, 0.0 },
  { 18, -0.2, { 4, 5, 2 }, 0.0 } };

ArpaParseOptions options;
options.bos_symbol = 1;

@@ -231,7 +232,6 @@ ngram 3=2\n\
\\3-grams:\n\
-0.3 <s> a \xCE\xB2\n\
-0.2 <s> a </s>\n\
\n\
\\end\\";

// Symbol table that is created with predefined test symbols, "a" but no "b".

@@ -270,7 +270,6 @@ void ReadSymbolicLmNoOovImpl(ArpaParseOptions::OovHandling oov) {
  options.bos_symbol = 1;
  options.eos_symbol = 2;
  options.unk_symbol = 3;
  options.use_log10 = true;
  options.oov_handling = oov;
  TestableArpaFileParser parser(options, &symbols);
  std::istringstream stm(symbolic_lm, std::ios_base::in);

@@ -301,7 +300,6 @@ void ReadSymbolicLmWithOovImpl(
  options.bos_symbol = 1;
  options.eos_symbol = 2;
  options.unk_symbol = 3;
  options.use_log10 = true;
  options.oov_handling = oov;
  TestableArpaFileParser parser(options, symbols);
  std::istringstream stm(symbolic_lm, std::ios_base::in);

@@ -31,7 +31,8 @@ namespace kaldi {

ArpaFileParser::ArpaFileParser(ArpaParseOptions options,
                               fst::SymbolTable* symbols)
    : options_(options), symbols_(symbols), line_number_(0) {
    : options_(options), symbols_(symbols),
      line_number_(0), warning_count_(0) {
}

ArpaFileParser::~ArpaFileParser() {

@@ -70,49 +71,51 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {

  ngram_counts_.clear();
  line_number_ = 0;
  warning_count_ = 0;
  current_line_.clear();

#define PARSE_ERR (KALDI_ERR << "in line " << line_number_ << ": ")
#define PARSE_ERR (KALDI_ERR << LineReference() << ": ")

  // Give derived class an opportunity to prepare its state.
  ReadStarted();

  std::string line;

  // Processes "\data\" section.
  bool keyword_found = false;
  while (++line_number_, getline(is, line) && !is.eof()) {
    if (line.empty()) continue;
  while (++line_number_, getline(is, current_line_) && !is.eof()) {
    if (current_line_.empty()) continue;

    // Continue skipping lines until the \data\ marker alone on a line is found.
    if (!keyword_found) {
      if (line == "\\data\\") {
      if (current_line_ == "\\data\\") {
        KALDI_LOG << "Reading \\data\\ section.";
        keyword_found = true;
      }
      continue;
    }

    if (line[0] == '\\') break;
    if (current_line_[0] == '\\') break;

    // Enters "\data\" section, and looks for patterns like "ngram 1=1000",
    // which means there are 1000 unigrams.
    std::size_t equal_symbol_pos = line.find("=");
    std::size_t equal_symbol_pos = current_line_.find("=");
    if (equal_symbol_pos != std::string::npos)
      line.replace(equal_symbol_pos, 1, " = ");  // Inserts spaces around "="
      // Guaranteed spaces around the "=".
      current_line_.replace(equal_symbol_pos, 1, " = ");
    std::vector<std::string> col;
    SplitStringToVector(line, " \t", true, &col);
    SplitStringToVector(current_line_, " \t", true, &col);
    if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") {
      int32 order, ngram_count = 0;
      if (!ConvertStringToInteger(col[1], &order) ||
          !ConvertStringToInteger(col[3], &ngram_count)) {
        PARSE_ERR << "Cannot parse ngram count '" << line << "'.";
        PARSE_ERR << "cannot parse ngram count";
      }
      if (ngram_counts_.size() <= order) {
        ngram_counts_.resize(order);
      }
      ngram_counts_[order - 1] = ngram_count;
    } else {
      KALDI_WARN << "Uninterpretable line in \\data\\ section: " << line;
      KALDI_WARN << LineReference()
                 << ": uninterpretable line in \\data\\ section";
    }
  }

@@ -136,41 +139,38 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
    // Must be looking at a \k-grams: directive at this point.
    std::ostringstream keyword;
    keyword << "\\" << cur_order << "-grams:";
    if (line != keyword.str()) {
      PARSE_ERR << "Invalid directive '" << line << "', "
                << "expecting '" << keyword.str() << "'.";
    if (current_line_ != keyword.str()) {
      PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'";
    }
    KALDI_LOG << "Reading " << line << " section.";
    KALDI_LOG << "Reading " << current_line_ << " section.";

    int32 ngram_count = 0;
    while (++line_number_, getline(is, line) && !is.eof()) {
      if (line.empty()) continue;
      if (line[0] == '\\') break;
    while (++line_number_, getline(is, current_line_) && !is.eof()) {
      if (current_line_.empty()) continue;
      if (current_line_[0] == '\\') break;

      std::vector<std::string> col;
      SplitStringToVector(line, " \t", true, &col);
      SplitStringToVector(current_line_, " \t", true, &col);

      if (col.size() < 1 + cur_order ||
          col.size() > 2 + cur_order ||
          (cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) {
        PARSE_ERR << "Invalid n-gram line '" << line << "'";
        PARSE_ERR << "Invalid n-gram data line";
      }
      ++ngram_count;

      // Parse out n-gram logprob and, if present, backoff weight.
      if (!ConvertStringToReal(col[0], &ngram.logprob)) {
        PARSE_ERR << "Invalid n-gram logprob '" << col[0] << "'.";
        PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'";
      }
      ngram.backoff = 0.0;
      if (col.size() > cur_order + 1) {
        if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff))
          PARSE_ERR << "Invalid backoff weight '" << col[cur_order + 1] << "'.";
      }
      // Convert to natural log unless the option is set not to.
      if (!options_.use_log10) {
        ngram.logprob *= M_LN10;
        ngram.backoff *= M_LN10;
          PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'";
      }
      // Convert to natural log.
      ngram.logprob *= M_LN10;
      ngram.backoff *= M_LN10;

      ngram.words.resize(cur_order);
      bool skip_ngram = false;

@@ -188,11 +188,14 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
            word = options_.unk_symbol;
            break;
          case ArpaParseOptions::kSkipNGram:
            if (ShouldWarn())
              KALDI_WARN << LineReference() << " skipped: word '"
                         << col[1 + index] << "' not in symbol table";
            skip_ngram = true;
            break;
          default:
            PARSE_ERR << "Word '" << col[1 + index]
                      << "' not in symbol table.";
            PARSE_ERR << "word '" << col[1 + index]
                      << "' not in symbol table";
        }
      }
    }

@@ -204,8 +207,8 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
      }
      // Whichever way we got it, an epsilon is invalid.
      if (word == 0) {
        PARSE_ERR << "Epsilon symbol '" << col[1 + index]
                  << "' is illegal in ARPA LM.";
        PARSE_ERR << "epsilon symbol '" << col[1 + index]
                  << "' is illegal in ARPA LM";
      }
      ngram.words[index] = word;
    }

@@ -214,20 +217,36 @@ void ArpaFileParser::Read(std::istream &is, bool binary) {
    }
  }
  if (ngram_count > ngram_counts_[cur_order - 1]) {
    PARSE_ERR << "Header said there would be " << ngram_counts_[cur_order - 1]
              << " n-grams of order " << cur_order << ", but we saw "
              << ngram_count;
    PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1]
              << " n-grams of order " << cur_order
              << ", but we saw more already.";
  }
}

if (line != "\\end\\") {
  PARSE_ERR << "Invalid or unexpected directive line '" << line << "', "
            << "expected \\end\\.";
if (current_line_ != "\\end\\") {
  PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\";
}

if (warning_count_ > 0 && warning_count_ > (uint32)options_.max_warnings) {
  KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
             << options_.max_warnings << " were reported. Run program with "
             << "--max-arpa-warnings=-1 to see all warnings";
}

current_line_.clear();
ReadComplete();

#undef PARSE_ERR
}

std::string ArpaFileParser::LineReference() const {
  std::stringstream ss;
  ss << "line " << line_number_ << " [" << current_line_ << "]";
  return ss.str();
}

bool ArpaFileParser::ShouldWarn() {
  return ++warning_count_ <= (uint32)options_.max_warnings;
}

}  // namespace kaldi

@@ -27,6 +27,7 @@
#include <fst/fst-decl.h>

#include "base/kaldi-types.h"
#include "itf/options-itf.h"

namespace kaldi {

@@ -43,13 +44,22 @@ struct ArpaParseOptions {

  ArpaParseOptions()
      : bos_symbol(-1), eos_symbol(-1), unk_symbol(-1),
        oov_handling(kRaiseError), use_log10(false) { }
        oov_handling(kRaiseError), max_warnings(30) { }

  void Register(OptionsItf *opts) {
    // Registering only the max_warnings count, since other options are
    // treated differently by client programs: some want integer symbols,
    // while others are passed words on their command line.
    opts->Register("max-arpa-warnings", &max_warnings,
                   "Maximum warnings to report on ARPA parsing, "
                   "0 to disable, -1 to show all");
  }

  int32 bos_symbol;   ///< Symbol for <s>. Required, non-epsilon.
  int32 eos_symbol;   ///< Symbol for </s>. Required, non-epsilon.
  int32 unk_symbol;   ///< Symbol for <unk>. Required for kReplaceWithUnk.
  OovHandling oov_handling;  ///< How to handle OOV words in the file.
  bool use_log10;     ///< Use log10 for prob and backoff weight, not ln.
  int32 max_warnings;  ///< Maximum warnings to report; <0 means unlimited.
};

/**

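Any tool that calls ArpaParseOptions::Register() picks the new flag up automatically. For example (hypothetical paths; flag name as registered above):

# Show every parse warning instead of the default cap of 30:
arpa2fst --max-arpa-warnings=-1 --disambig-symbol='#0' \
  --read-symbol-table=data/lang/words.txt lm/input.arpa G.fst
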
@@ -111,6 +121,14 @@ class ArpaFileParser {
  /// Inside ConsumeNGram(), provides the current line number.
  int32 LineNumber() const { return line_number_; }

  /// Inside ConsumeNGram(), returns a formatted reference to the line being
  /// compiled, to print out as part of diagnostics.
  std::string LineReference() const;

  /// Increments warning count, and returns true if a warning should be
  /// printed or false if the count has exceeded the set maximum.
  bool ShouldWarn();

  /// N-gram counts. Valid in and after a call to HeaderAvailable().
  const std::vector<int32>& NgramCounts() const { return ngram_counts_; }

@@ -118,6 +136,8 @@ class ArpaFileParser {
  ArpaParseOptions options_;
  fst::SymbolTable* symbols_;  // Not owned.
  int32 line_number_;
  uint32 warning_count_;
  std::string current_line_;
  std::vector<int32> ngram_counts_;
};

@@ -106,18 +106,21 @@ class OptimizedHistKey {
  uint64 data_;
};

}  // namespace

template <class HistKey>
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
 public:
  ArpaLmCompilerImpl(fst::StdVectorFst* fst, Symbol bos_symbol,
                     Symbol eos_symbol, Symbol sub_eps);
  ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
                     Symbol sub_eps);

  virtual void ConsumeNGram(const NGram& ngram, bool is_highest);
  virtual void ConsumeNGram(const NGram &ngram, bool is_highest);

 private:
  StateId AddStateWithBackoff(HistKey key, float backoff);
  void CreateBackoff(HistKey key, StateId state, float weight);

  ArpaLmCompiler *parent_;  // Not owned.
  fst::StdVectorFst* fst_;  // Not owned.
  Symbol bos_symbol_;
  Symbol eos_symbol_;

@@ -131,10 +134,9 @@ class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {

template <class HistKey>
ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
    fst::StdVectorFst* fst, Symbol bos_symbol,
    Symbol eos_symbol, Symbol sub_eps) : fst_(fst), bos_symbol_(bos_symbol),
                                         eos_symbol_(eos_symbol),
                                         sub_eps_(sub_eps) {
    ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
    : parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
      eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) {
  // The algorithm maintains state per history. The 0-gram is a special state
  // for the empty history. All unigrams (including BOS) back off into this state.
  StateId zerogram = fst_->AddState();

@@ -150,8 +152,8 @@ ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
}

template <class HistKey>
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(
    const NGram& ngram, bool is_highest) {
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
                                               bool is_highest) {
  // Generally, we do the following. Suppose we are adding an n-gram "A B
  // C". Then find the node for "A B", add a new node for "A B C", and connect
  // them with the arc accepting "C" with the specified weight. Also, add a

@@ -181,7 +183,9 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(
  if (source_it == history_.end()) {
    // There was no "A B", therefore the probability of "A B C" is zero.
    // Print a warning and discard current n-gram.
    KALDI_WARN << "No parent gram, skipping";
    if (parent_->ShouldWarn())
      KALDI_WARN << parent_->LineReference()
                 << " skipped: no parent (n-1)-gram exists";
    return;
  }

@@ -225,6 +229,7 @@ void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(

  // Add arc from source to dest, whichever way it was found.
  fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
  return;
}

// Find or create a new state for n-gram defined by key, and ensure it has a

@@ -266,8 +271,6 @@ inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(
  fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
}

}  // namespace

ArpaLmCompiler::~ArpaLmCompiler() {
  if (impl_ != NULL)
    delete impl_;

@@ -286,25 +289,24 @@ void ArpaLmCompiler::HeaderAvailable() {
  max_symbol += NgramCounts()[0];

  if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
    impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(
        &fst_, Options().bos_symbol, Options().eos_symbol, sub_eps_);
    impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
  } else {
    impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(
        &fst_, Options().bos_symbol, Options().eos_symbol, sub_eps_);
    impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
    KALDI_LOG << "Reverting to slower state tracking because model is large: "
              << NgramCounts().size() << "-gram with symbols up to "
              << max_symbol;
  }
}

void ArpaLmCompiler::ConsumeNGram(const NGram& ngram) {
void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
  // <s> is invalid in tails, </s> in heads of an n-gram.
  for (int i = 0; i < ngram.words.size(); ++i) {
    if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
        (i + 1 < ngram.words.size()
         && ngram.words[i] == Options().eos_symbol)) {
      KALDI_WARN << "In line " << LineNumber()
                 << ": Skipping n-gram with invalid BOS/EOS placement";
      if (ShouldWarn())
        KALDI_WARN << LineReference()
                   << " skipped: n-gram has invalid BOS/EOS placement";
      return;
    }
  }

@@ -39,6 +39,7 @@ class ArpaLmCompiler : public ArpaFileParser {
  ~ArpaLmCompiler();

  const fst::StdVectorFst& Fst() const { return fst_; }
  fst::StdVectorFst* MutableFst() { return &fst_; }

 protected:
  // ArpaFileParser overrides.

@@ -50,6 +51,7 @@ class ArpaLmCompiler : public ArpaFileParser {
  int sub_eps_;
  ArpaLmCompilerImplInterface* impl_;  // Owned.
  fst::StdVectorFst fst_;
  template <class HistKey> friend class ArpaLmCompilerImpl;
};

}  // namespace kaldi

@@ -268,7 +268,7 @@ void ConstArpaLmBuilder::HeaderAvailable() {
  ngram_order_ = NgramCounts().size();
}

void ConstArpaLmBuilder::ConsumeNGram(const NGram& ngram) {
void ConstArpaLmBuilder::ConsumeNGram(const NGram &ngram) {
  int32 cur_order = ngram.words.size();
  // If <ngram_order_> is larger than 1, then we do not create LmState for
  // the final order entry. We only keep the log probability for it.

@@ -1062,15 +1062,9 @@ bool ConstArpaLmDeterministicFst::GetArc(StateId s,
  return true;
}

bool BuildConstArpaLm(const int32 bos_symbol, const int32 eos_symbol,
                      const int32 unk_symbol,
bool BuildConstArpaLm(const ArpaParseOptions& options,
                      const std::string& arpa_rxfilename,
                      const std::string& const_arpa_wxfilename) {
  ArpaParseOptions options;
  options.bos_symbol = bos_symbol;
  options.eos_symbol = eos_symbol;
  options.unk_symbol = unk_symbol;

  ConstArpaLmBuilder lm_builder(options);
  KALDI_LOG << "Reading " << arpa_rxfilename;
  ReadKaldiObject(arpa_rxfilename, &lm_builder);

@@ -25,6 +25,7 @@

#include "base/kaldi-common.h"
#include "fstext/deterministic-fst.h"
#include "lm/arpa-file-parser.h"
#include "util/common-utils.h"

namespace kaldi {

@@ -418,8 +419,7 @@ class ConstArpaLmDeterministicFst
// Reads in an Arpa format language model and converts it into ConstArpaLm
// format. We assume that the words in the input Arpa format language model have
// been converted into integers.
bool BuildConstArpaLm(const int32 bos_symbol, const int32 eos_symbol,
                      const int32 unk_symbol,
bool BuildConstArpaLm(const ArpaParseOptions& options,
                      const std::string& arpa_rxfilename,
                      const std::string& const_arpa_wxfilename);

@@ -13,7 +13,7 @@
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

@@ -33,7 +33,7 @@ int main(int argc, char *argv[]) {
        "that wants to rescore lattices. We assume that the words in the\n"
        "input arpa language model have been converted to integers.\n"
        "\n"
        "The program is used joinly with utils/map_arpa_lm.pl to build\n"
        "The program is used jointly with utils/map_arpa_lm.pl to build\n"
        "ConstArpaLm format language model. We first map the words in an Arpa\n"
        "format language model to integers using utils/map_arpa_lm.pl, and\n"
        "then use this program to build a ConstArpaLm format language model.\n"

@@ -44,16 +44,19 @@ int main(int argc, char *argv[]) {

    kaldi::ParseOptions po(usage);

    int32 unk_symbol = -1;
    int32 bos_symbol = -1;
    int32 eos_symbol = -1;
    po.Register("unk-symbol", &unk_symbol,
    ArpaParseOptions options;
    options.Register(&po);

    // Ideally, these registrations would be in ArpaParseOptions, but some
    // programs want integers and others want symbols, so we register them
    // outside instead.
    po.Register("unk-symbol", &options.unk_symbol,
                "Integer corresponds to unknown-word in language model. -1 if "
                "no such word is provided.");
    po.Register("bos-symbol", &bos_symbol,
    po.Register("bos-symbol", &options.bos_symbol,
                "Integer corresponds to <s>. You must set this to your actual "
                "BOS integer.");
    po.Register("eos-symbol", &eos_symbol,
    po.Register("eos-symbol", &options.eos_symbol,
                "Integer corresponds to </s>. You must set this to your actual "
                "EOS integer.");

@@ -64,7 +67,7 @@ int main(int argc, char *argv[]) {
      exit(1);
    }

    if (bos_symbol == -1 || eos_symbol == -1) {
    if (options.bos_symbol == -1 || options.eos_symbol == -1) {
      KALDI_ERR << "Please set --bos-symbol and --eos-symbol.";
      exit(1);
    }

@@ -72,8 +75,8 @@ int main(int argc, char *argv[]) {
    std::string arpa_rxfilename = po.GetArg(1),
        const_arpa_wxfilename = po.GetOptArg(2);

    bool ans = BuildConstArpaLm(bos_symbol, eos_symbol, unk_symbol,
                                arpa_rxfilename, const_arpa_wxfilename);
    bool ans = BuildConstArpaLm(options, arpa_rxfilename,
                                const_arpa_wxfilename);
    if (ans)
      return 0;
    else

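As the usage text above says, the words must be mapped to integers first. A hedged end-to-end sketch with utils/map_arpa_lm.pl (paths and the symbol IDs 40/41/42 are illustrative assumptions; use the actual <s>, </s>, and <unk> IDs from your words.txt):

# Map ARPA words to integer IDs, then build the ConstArpaLm-format LM.
utils/map_arpa_lm.pl data/lang/words.txt < lm/input.arpa > lm/input.int.arpa
arpa-to-const-arpa --bos-symbol=40 --eos-symbol=41 --unk-symbol=42 \
  lm/input.int.arpa data/lang/G.carpa
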
@@ -34,6 +34,9 @@ int main(int argc, char *argv[]) {
        "data/lang/words.txt lm/input.arpa G.fst\n";
    ParseOptions po(usage);

    ArpaParseOptions options;
    options.Register(&po);

    // Option flags.
    std::string bos_symbol = "<s>";
    std::string eos_symbol = "</s>";

@@ -41,6 +44,7 @@ int main(int argc, char *argv[]) {
    std::string read_syms_filename;
    std::string write_syms_filename;
    bool keep_symbols = false;
    bool ilabel_sort = true;

    po.Register("bos-symbol", &bos_symbol,
                "Beginning of sentence symbol");

@@ -56,6 +60,8 @@ int main(int argc, char *argv[]) {
    po.Register("keep-symbols", &keep_symbols,
                "Store symbol table with FST. Forced true if "
                "symbol tables are neither read nor written");
    po.Register("ilabel-sort", &ilabel_sort,
                "Ilabel-sort the output FST");

    po.Read(argc, argv);

@@ -66,7 +72,6 @@ int main(int argc, char *argv[]) {
    std::string arpa_rxfilename = po.GetArg(1),
        fst_wxfilename = po.GetOptArg(2);

    ArpaParseOptions options;
    int64 disambig_symbol_id = 0;

    fst::SymbolTable* symbols;

@@ -110,6 +115,11 @@ int main(int argc, char *argv[]) {
    ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
    ReadKaldiObject(arpa_rxfilename, &lm_compiler);

    // Sort the FST in-place if requested by options.
    if (ilabel_sort) {
      fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
    }

    // Write symbols if requested.
    if (!write_syms_filename.empty()) {
      kaldi::Output kosym(write_syms_filename, false);

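With the new options wired in, the common invocation matches the usage string above; shown here as a sketch (words.txt must already contain the #0 backoff symbol):

# Compile an ARPA LM into an ilabel-sorted G.fst using an existing word list.
arpa2fst --disambig-symbol='#0' --read-symbol-table=data/lang/words.txt \
  lm/input.arpa G.fst
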
@@ -220,8 +220,9 @@ inline void Component::Propagate(const CuMatrixBase<BaseFloat> &in,
                                 CuMatrix<BaseFloat> *out) {
  // Check the dims
  if (input_dim_ != in.NumCols()) {
    KALDI_ERR << "Non-matching dims! " << TypeToMarker(GetType())
              << " input-dim : " << input_dim_ << " data : " << in.NumCols();
    KALDI_ERR << "Non-matching dims on the input of " << TypeToMarker(GetType())
              << " component. The input-dim is " << input_dim_
              << ", the data had " << in.NumCols() << " dims.";
  }
  // Allocate target buffer
  out->Resize(in.NumRows(), output_dim_, kSetZero);  // reset

@@ -41,7 +41,7 @@ namespace nnet1 {
class SimpleSentenceAveragingComponent : public Component {
 public:
  SimpleSentenceAveragingComponent(int32 dim_in, int32 dim_out)
    : Component(dim_in, dim_out), gradient_boost_(100.0)
    : Component(dim_in, dim_out), gradient_boost_(100.0), shrinkage_(0.0), only_summing_(false)
  { }
  ~SimpleSentenceAveragingComponent()
  { }

@@ -56,7 +56,9 @@ class SimpleSentenceAveragingComponent : public Component {
    while (!is.eof()) {
      ReadToken(is, false, &token);
      if (token == "<GradientBoost>") ReadBasicType(is, false, &gradient_boost_);
      else KALDI_ERR << "Unknown token " << token << ", a typo in config? (GradientBoost)";
      else if (token == "<Shrinkage>") ReadBasicType(is, false, &shrinkage_);
      else if (token == "<OnlySumming>") ReadBasicType(is, false, &only_summing_);
      else KALDI_ERR << "Unknown token " << token << ", a typo in config? (GradientBoost|Shrinkage|OnlySumming)";
      is >> std::ws;  // eat-up whitespace
    }
  }

@@ -64,15 +66,29 @@ class SimpleSentenceAveragingComponent : public Component {
  void ReadData(std::istream &is, bool binary) {
    ExpectToken(is, binary, "<GradientBoost>");
    ReadBasicType(is, binary, &gradient_boost_);
    if (PeekToken(is, binary) == 'S') {
      ExpectToken(is, binary, "<Shrinkage>");
      ReadBasicType(is, binary, &shrinkage_);
    }
    if (PeekToken(is, binary) == 'O') {
      ExpectToken(is, binary, "<OnlySumming>");
      ReadBasicType(is, binary, &only_summing_);
    }
  }

  void WriteData(std::ostream &os, bool binary) const {
    WriteToken(os, binary, "<GradientBoost>");
    WriteBasicType(os, binary, gradient_boost_);
    WriteToken(os, binary, "<Shrinkage>");
    WriteBasicType(os, binary, shrinkage_);
    WriteToken(os, binary, "<OnlySumming>");
    WriteBasicType(os, binary, only_summing_);
  }

  std::string Info() const {
    return std::string("\n gradient-boost ") + ToString(gradient_boost_);
    return std::string("\n gradient-boost ") + ToString(gradient_boost_) +
           ", shrinkage: " + ToString(shrinkage_) +
           ", only summing: " + ToString(only_summing_);
  }
  std::string InfoGradient() const {
    return Info();

@@ -81,7 +97,11 @@ class SimpleSentenceAveragingComponent : public Component {
  void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
    // get the average row-vector,
    average_row_.Resize(InputDim());
    average_row_.AddRowSumMat(1.0/in.NumRows(), in, 0.0);
    if (only_summing_) {
      average_row_.AddRowSumMat(1.0, in, 0.0);
    } else {
      average_row_.AddRowSumMat(1.0/(in.NumRows()+shrinkage_), in, 0.0);
    }
    // copy it to the output,
    out->AddVecToRows(1.0, average_row_, 0.0);
  }

@@ -97,7 +117,11 @@ class SimpleSentenceAveragingComponent : public Component {
    //
    // getting the average output diff,
    average_diff_.Resize(OutputDim());
    average_diff_.AddRowSumMat(1.0/out_diff.NumRows(), out_diff, 0.0);
    if (only_summing_) {
      average_diff_.AddRowSumMat(1.0, out_diff, 0.0);
    } else {
      average_diff_.AddRowSumMat(1.0/(out_diff.NumRows()+shrinkage_), out_diff, 0.0);
    }
    // copy the derivative into the input diff, (applying gradient-boost!!)
    in_diff->AddVecToRows(gradient_boost_, average_diff_, 0.0);
  }

@@ -106,6 +130,9 @@ class SimpleSentenceAveragingComponent : public Component {
  CuVector<BaseFloat> average_row_;   ///< auxiliary buffer for forward propagation,
  CuVector<BaseFloat> average_diff_;  ///< auxiliary buffer for backpropagation,
  BaseFloat gradient_boost_;  ///< increase of gradient applied in backpropagation,
  BaseFloat shrinkage_;  ///< Number of 'imaginary' zero-vectors in the average
                         ///< (shrinks the average vector for shorter sentences),
  bool only_summing_;  ///< Removes normalization term from arithmetic mean (when true).
};

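Restating the forward pass above as a formula (a direct reading of the code, with s = shrinkage_ and N input rows x_i):

\bar{x} = \sum_{i=1}^{N} x_i                        \quad\text{if only\_summing\_,}
\bar{x} = \frac{1}{N + s} \sum_{i=1}^{N} x_i        \quad\text{otherwise,}

so a positive shrinkage behaves like s imaginary zero rows mixed into the average, shrinking the per-sentence mean for short sentences; the backward pass applies the same scaling to the output derivatives before multiplying by gradient_boost_.
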
@@ -515,7 +515,7 @@ void CachingOptimizingCompiler::ReadCache(std::istream &is, bool binary) {
  bool read_cache = (opt_config_ == opt_config_cached);

  if (read_cache) {
    size_t computation_cache_size;
    int32 computation_cache_size;
    ExpectToken(is, binary, "<ComputationCacheSize>");
    ReadBasicType(is, binary, &computation_cache_size);
    KALDI_ASSERT(computation_cache_size >= 0);