110 строки
3.7 KiB
Bash
Executable File
110 строки
3.7 KiB
Bash
Executable File
#!/bin/bash -v
|
|
|
|
MARIAN=../../build
|
|
|
|
# if we are in WSL, we need to add '.exe' to the tool names
|
|
if [ -e "/bin/wslpath" ]
|
|
then
|
|
EXT=.exe
|
|
fi
|
|
|
|
MARIAN_TRAIN=$MARIAN/marian$EXT
|
|
MARIAN_DECODER=$MARIAN/marian-decoder$EXT
|
|
MARIAN_VOCAB=$MARIAN/marian-vocab$EXT
|
|
MARIAN_SCORER=$MARIAN/marian-scorer$EXT
|
|
|
|
# set chosen gpus
|
|
GPUS=0
|
|
if [ $# -ne 0 ]
|
|
then
|
|
GPUS=$@
|
|
fi
|
|
echo Using GPUs: $GPUS
|
|
|
|
if [ ! -e $MARIAN_TRAIN ]
|
|
then
|
|
echo "marian is not installed in $MARIAN, you need to compile the toolkit first"
|
|
exit 1
|
|
fi
|
|
|
|
if [ ! -e ../tools/moses-scripts ] || [ ! -e ../tools/subword-nmt ] || [ ! -e ../tools/sacreBLEU ]
|
|
then
|
|
echo "missing tools in ../tools, you need to download them first"
|
|
exit 1
|
|
fi
|
|
|
|
if [ ! -e "data/corpus.en" ]
|
|
then
|
|
./scripts/download-files.sh
|
|
fi
|
|
|
|
export MODEL=`pwd`/../../../keep
|
|
|
|
# preprocess data
|
|
if [ ! -e "data/corpus.bpe.en" ]
|
|
then
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt13 -l en-de --echo src > data/valid.en
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt13 -l en-de --echo ref > data/valid.de
|
|
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt14 -l en-de --echo src > data/test2014.en
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt15 -l en-de --echo src > data/test2015.en
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt16 -l en-de --echo src > data/test2016.en
|
|
|
|
./scripts/preprocess-data.sh
|
|
fi
|
|
|
|
# create common vocabulary
|
|
if [ ! -e "$MODEL/vocab.ende.yml" ]
|
|
then
|
|
cat data/corpus.bpe.en data/corpus.bpe.de | $MARIAN_VOCAB --max-size 36000 > $MODEL/vocab.ende.yml
|
|
fi
|
|
|
|
# train model
|
|
if [ ! -e "$MODEL/model.npz" ]
|
|
then
|
|
$MARIAN_TRAIN \
|
|
--model $MODEL/model.npz --type transformer \
|
|
--train-sets data/corpus.bpe.en data/corpus.bpe.de \
|
|
--max-length 100 \
|
|
--vocabs $MODEL/vocab.ende.yml $MODEL/vocab.ende.yml \
|
|
--mini-batch-fit -w 22000 --maxi-batch 1000 \
|
|
--early-stopping 10 --cost-type=ce-mean-words \
|
|
--valid-freq 5000 --save-freq 5000 --disp-freq 500 \
|
|
--valid-metrics ce-mean-words perplexity translation \
|
|
--valid-sets data/valid.bpe.en data/valid.bpe.de \
|
|
--valid-script-path "bash ./scripts/validate.sh" \
|
|
--valid-translation-output data/valid.bpe.en.output --quiet-translation \
|
|
--valid-mini-batch 64 \
|
|
--beam-size 6 --normalize 0.6 \
|
|
--log $MODEL/train.log --valid-log $MODEL/valid.log \
|
|
--enc-depth 6 --dec-depth 6 \
|
|
--transformer-heads 8 \
|
|
--transformer-postprocess-emb d \
|
|
--transformer-postprocess dan \
|
|
--transformer-dropout 0.1 --label-smoothing 0.1 \
|
|
--learn-rate 0.0003 --lr-warmup 16000 --lr-decay-inv-sqrt 16000 --lr-report \
|
|
--optimizer-params 0.9 0.98 1e-09 --clip-norm 5 \
|
|
--tied-embeddings-all \
|
|
--devices $GPUS --sync-sgd --seed 1111 \
|
|
--exponential-smoothing
|
|
fi
|
|
|
|
# find best model on dev set
|
|
ITER=`cat $MODEL/valid.log | grep translation | sort -rg -k12,12 -t' ' | cut -f8 -d' ' | head -n1`
|
|
|
|
# translate test sets
|
|
for prefix in test2014 test2015 test2016
|
|
do
|
|
cat data/$prefix.bpe.en \
|
|
| $MARIAN_DECODER -c $MODEL/model.npz.decoder.yml -m $MODEL/model.iter$ITER.npz -d $GPUS -b 12 -n -w 6000 \
|
|
| sed 's/\@\@ //g' \
|
|
| ../tools/moses-scripts/scripts/recaser/detruecase.perl \
|
|
| ../tools/moses-scripts/scripts/tokenizer/detokenizer.perl -l de \
|
|
> data/$prefix.de.output
|
|
done
|
|
|
|
# calculate bleu scores on test sets
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt14 -l en-de < data/test2014.de.output
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt15 -l en-de < data/test2015.de.output
|
|
LC_ALL=C.UTF-8 ../tools/sacreBLEU/sacrebleu.py -t wmt16 -l en-de < data/test2016.de.output
|