Mirror of https://github.com/microsoft/inmt.git

Update OpenNMT-py to c20dbeac02688918607637f5f30ec73c0f17d817

Parent: ea5d36d9e7
Commit: 70ef77eede
@@ -1,6 +1,5 @@
 model/
 *.pyc
 db.sqlite3
-opennmt/.git
 pred.txt
 conf.py

@@ -4,5 +4,6 @@ RUN mkdir /inmt
 WORKDIR /inmt
 COPY requirements.txt /inmt/
 RUN pip install -r requirements.txt
+RUN cd OpenNMT-py
+RUN python setup.py install
 COPY . /inmt/

README.md (19 lines changed)

@@ -3,25 +3,32 @@ Interactive Machine Translation app uses Django and jQuery as its tech stack. Pl
 
 # Installation Instructions
 
-Make a new model folder using `mkdir model` where the models need to be placed. Models can be downloaded from [https://microsoft-my.sharepoint.com/:f:/p/t-sesan/Evsn3riZxktJuterr5A09lABTVjhaL_NoH430IMgkzws9Q?e=VXzX5T].
+1. Clone INMT and prepare MT models:
+```
+git clone https://github.com/microsoft/inmt
+```
+2. Make a new model folder using `mkdir model` where the models need to be placed. Models can be downloaded from [here](https://microsoft-my.sharepoint.com/:f:/p/t-sesan/Evsn3riZxktJuterr5A09lABTVjhaL_NoH430IMgkzws9Q?e=VXzX5T). These contain English to Hindi translation models in both directions. If you want to train your own models, refer to [Training MT Models](#training-mt-models).
 
+3. The rest of the installation can be carried out either bare or using Docker. Docker is preferable for its ease of installation.
 
 ## Docker Installation
 Assuming you have Docker set up on your system, simply run `docker-compose up -d`. This application requires at least 4 GB of memory in order to run. Allot your Docker memory accordingly.
 
 ## Bare Installation
 1. Install dependencies using - `python -m pip install -r requirements.txt`. Be sure to check your Python version; this tool is compatible with Python 3.
-2. Run the migrations and start the server - `python manage.py makemigrations && python manage.py migrate && python manage.py runserver`
-3. The server opens on port 8000 by default. Open `localhost:8000/simple` for the simple interface.
+2. Install the OpenNMT dependencies using - `cd opennmt && python setup.py install && cd -`
+3. Run the migrations and start the server - `python manage.py makemigrations && python manage.py migrate && python manage.py runserver`
+4. The server opens on port 8000 by default. Open `localhost:8000/simple` for the simple interface.
 
 ## Training MT Models
 OpenNMT is used as the translation engine to power INMT. In order to train your own models, you need parallel sentences in your desired language. The basic instructions are listed as follows:
 1. Go to the opennmt folder: `cd opennmt`
 2. Preprocess parallel (src & tgt) sentences:
-```python preprocess.py -train_src <src_lang_train> -train_tgt <tgt_lang_train> -valid_src <src_lang_valid> -valid_tgt <tgt_lang_valid> -save_data <processed_data>```
+```onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo```
 3. Train your model (with GPUs):
-```python train.py -data <processed_data> -save_model <model_name> -gpu_ranks 0```
+```onmt_train -data data/demo -save_model demo-model```
 
-For more information on the training process, refer to [OpenNMT docs](https://opennmt.net/OpenNMT-py/quickstart.html).
+For advanced instructions on the training process, refer to the [OpenNMT docs](https://opennmt.net/OpenNMT-py/quickstart.html).
 
 # Contributing

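For reference, the bare-installation steps in the updated README chain together as follows. This is only a consolidated sketch of the commands listed above (assuming Python 3 and the downloaded models already placed in `model/`):

```bash
# Install INMT's Python dependencies
python -m pip install -r requirements.txt

# Install the vendored OpenNMT-py package
cd opennmt && python setup.py install && cd -

# Apply migrations and start the development server (port 8000 by default)
python manage.py makemigrations
python manage.py migrate
python manage.py runserver
# then open http://localhost:8000/simple for the simple interface
```
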
@@ -1,8 +1,7 @@
 dist: xenial
 language: python
 python:
-- "2.7"
-- "3.5"
+- "3.6"
 git:
 depth: false
 addons:

@@ -14,13 +13,30 @@ addons:
 - sox
 before_install:
 # Install CPU version of PyTorch.
-- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp27-cp27mu-linux_x86_64.whl; fi
-- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp35-cp35m-linux_x86_64.whl; fi
-- pip install -r requirements.txt
+- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install torch==1.4.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi
+- pip install --upgrade setuptools
 - pip install -r requirements.opt.txt
-install:
 - python setup.py install
+env:
+global:
+# Doctr deploy key for OpenNMT/OpenNMT-py
+- secure: "gL0Soefo1cQgAqwiHUrlNyZd/+SI1eJAAjLD3BEDQWXW160eXyjQAAujGgJoCirjOM7cPHVwLzwmK3S7Y3PVM3JOZguOX5Yl4uxMh/mhiEM+RG77SZyv4OGoLFsEQ8RTvIdYdtP6AwyjlkRDXvZql88TqFNYjpXDu8NG+JwEfiIoGIDYxxZ5SlbrZN0IqmQSZ4/CsV6VQiuq99Jn5kqi4MnUZBTcmhqjaztCP1omvsMRdbrG2IVhDKQOCDIO0kaPJrMy2SGzP4GV7ar52bdBtpeP3Xbm6ZOuhDNfds7M/OMHp1wGdl7XwKtolw9MeXhnGBC4gcrqhhMfcQ6XtfVLMLnsB09Ezl3FXX5zWgTB5Pm0X6TgnGrMA25MAdVqKGJpfqZxOKTh4EMb04b6OXrVbxZ88mp+V0NopuxwlTPD8PMfYLWlTe9chh1BnT0iQlLqeA4Hv3+NdpiFb4aq3V3cWTTgMqOoWSGq4t318pqIZ3qbBXBq12DLFgO5n6+M6ZrdxbDUGQvgh8nAiZcIEdodKJ4ABHi1SNCeWOzCoedUdegcbjShHfkMVmNKrncB18aRWwQ3GQJ5qdkjgJmC++uZmkS6+GPM8UmmAy1ZIkRW0aWiitjG6teqtvUHOofNd/TCxX4bhnxAj+mtVIrARCE/ci8topJ6uG4wVJ1TrIkUlAY="
+
+jobs:
+include:
+- name: "Flake8 Lint Check"
+env: LINT_CHECK
+install: pip install flake8 pep8-naming==0.7.0
+script: flake8
+- name: "Build Docs"
+install:
+- pip install doctr
+script:
+- pip install -r docs/requirements.txt
+- set -e
+- cd docs/ && make html && cd ..
+- doctr deploy --built-docs docs/build/html/ .
+- name: "Unit tests"
 # Please also add tests to `test/pull_request_chk.sh`.
 script:
 - wget -O /tmp/im2text.tgz http://lstm.seas.harvard.edu/latex/im2text_small.tgz; tar zxf /tmp/im2text.tgz -C /tmp/; head /tmp/im2text/src-train.txt > /tmp/im2text/src-train-head.txt; head /tmp/im2text/tgt-train.txt > /tmp/im2text/tgt-train-head.txt; head /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt

@@ -47,6 +63,9 @@ script:
 # test nmt preprocessing w/ sharding and training w/copy
 - head -50 data/src-val.txt > /tmp/src-val.txt; head -50 data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -shard_size 25 -dynamic_dict -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -copy_attn -train_steps 10 -pool_factor 10 && rm -rf /tmp/q*.pt
 
+# test Graph neural network preprocessing and training
+- cp data/ggnnsrc.txt /tmp/src-val.txt; cp data/ggnntgt.txt /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -src_seq_length 1000 -tgt_seq_length 30 -src_vocab data/ggnnsrcvocab.txt -tgt_vocab data/ggnntgtvocab.txt -dynamic_dict -save_data /tmp/q ; python train.py -data /tmp/q -encoder_type ggnn -layers 2 -decoder_type rnn -rnn_size 256 -learning_rate 0.1 -learning_rate_decay 0.8 -global_attention general -batch_size 32 -word_vec_size 256 -bridge -train_steps 10 -src_vocab data/ggnnsrcvocab.txt -n_edge_types 9 -state_dim 256 -n_steps 10 -n_node 64 && rm -rf /tmp/q*.pt
+
 # test im2text preprocessing and training
 - head -50 /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head -50 /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python preprocess.py -data_type img -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-val-head.txt -train_tgt /tmp/im2text/tgt-val-head.txt -valid_src /tmp/im2text/src-val-head.txt -valid_tgt /tmp/im2text/tgt-val-head.txt -save_data /tmp/im2text/q -tgt_seq_length 100; python train.py -model_type img -data /tmp/im2text/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 -pool_factor 10 && rm -rf /tmp/im2text/q*.pt
 # test speech2text preprocessing and training

@@ -57,24 +76,3 @@ script:
 - python translate.py -model onmt/tests/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 1 -seed 1 -random_sampling_topk "-1" -random_sampling_temp 0.0001 -tgt data/morph/tgt.valid -out /tmp/trans; diff data/morph/tgt.valid /tmp/trans
 # test tool
 - PYTHONPATH=$PYTHONPATH:. python tools/extract_embeddings.py -model onmt/tests/test_model.pt
 
-env:
-global:
-# Doctr deploy key for OpenNMT/OpenNMT-py
-- secure: "gL0Soefo1cQgAqwiHUrlNyZd/+SI1eJAAjLD3BEDQWXW160eXyjQAAujGgJoCirjOM7cPHVwLzwmK3S7Y3PVM3JOZguOX5Yl4uxMh/mhiEM+RG77SZyv4OGoLFsEQ8RTvIdYdtP6AwyjlkRDXvZql88TqFNYjpXDu8NG+JwEfiIoGIDYxxZ5SlbrZN0IqmQSZ4/CsV6VQiuq99Jn5kqi4MnUZBTcmhqjaztCP1omvsMRdbrG2IVhDKQOCDIO0kaPJrMy2SGzP4GV7ar52bdBtpeP3Xbm6ZOuhDNfds7M/OMHp1wGdl7XwKtolw9MeXhnGBC4gcrqhhMfcQ6XtfVLMLnsB09Ezl3FXX5zWgTB5Pm0X6TgnGrMA25MAdVqKGJpfqZxOKTh4EMb04b6OXrVbxZ88mp+V0NopuxwlTPD8PMfYLWlTe9chh1BnT0iQlLqeA4Hv3+NdpiFb4aq3V3cWTTgMqOoWSGq4t318pqIZ3qbBXBq12DLFgO5n6+M6ZrdxbDUGQvgh8nAiZcIEdodKJ4ABHi1SNCeWOzCoedUdegcbjShHfkMVmNKrncB18aRWwQ3GQJ5qdkjgJmC++uZmkS6+GPM8UmmAy1ZIkRW0aWiitjG6teqtvUHOofNd/TCxX4bhnxAj+mtVIrARCE/ci8topJ6uG4wVJ1TrIkUlAY="
-
-matrix:
-include:
-- env: LINT_CHECK
-python: "2.7"
-install: pip install flake8 pep8-naming==0.7.0
-script: flake8
-- python: "3.5"
-install:
-- python setup.py install
-- pip install doctr
-script:
-- pip install -r docs/requirements.txt
-- set -e
-- cd docs/ && make html && cd ..
-- doctr deploy --built-docs docs/build/html/ .

@@ -3,7 +3,60 @@
 
 
 ## [Unreleased]
 
+## [1.1.1](https://github.com/OpenNMT/OpenNMT-py/tree/1.1.1) (2020-03-20)
 ### Fixes and improvements
+* Fix backcompatibility when no 'corpus_id' field (c313c28)
+
+## [1.1.0](https://github.com/OpenNMT/OpenNMT-py/tree/1.1.0) (2020-03-19)
+### New features
+* Support CTranslate2 models in REST server (91d5d57)
+* Extend support for custom preprocessing/postprocessing function in REST server by using return dictionaries (d14613d, 9619ac3, 92a7ba5)
+* Experimental: BART-like source noising (5940dcf)
+
+### Fixes and improvements
+* Add options to CTranslate2 release (e442f3f)
+* Fix dataset shard order (458fc48)
+* Rotate only the server logs, not training (189583a)
+* Fix alignment error with empty prediction (91287eb)
+
+## [1.0.2](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.2) (2020-03-05)
+### Fixes and improvements
+* Enable CTranslate2 conversion of Transformers with relative position (db11135)
+* Adapt `-replace_unk` to use with learned alignments if they exist (7625b53)
+
+## [1.0.1](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.1) (2020-02-17)
+### Fixes and improvements
+* Ctranslate2 conversion handled in release script (1b50e0c)
+* Use `attention_dropout` properly in MHA (f5c9cd4)
+* Update apex FP16_Optimizer path (d3e2268)
+* Some REST server optimizations
+* Fix and add some docs
+
+## [1.0.0](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.0) (2019-10-01)
+### New features
+* Implementation of "Jointly Learning to Align & Translate with Transformer" (@Zenglinxiao)
+
+### Fixes and improvements
+* Add nbest support to REST server (@Zenglinxiao)
+* Merge greedy and beam search codepaths (@Zenglinxiao)
+* Fix "block ngram repeats" (@KaijuML, @pltrdy)
+* Small fixes, some more docs
+
+## [1.0.0.rc2](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.0.rc1) (2019-10-01)
+* Fix Apex / FP16 training (Apex new API is buggy)
+* Multithread preprocessing way faster (Thanks @francoishernandez)
+* Pip Installation v1.0.0.rc1 (thanks @pltrdy)
+
+## [0.9.2](https://github.com/OpenNMT/OpenNMT-py/tree/0.9.2) (2019-09-04)
+* Switch to Pytorch 1.2
+* Pre/post processing on the translation server
+* option to remove the FFN layer in AAN + AAN optimization (faster)
+* Coverage loss (per Abisee paper 2017) implementation
+* Video Captioning task: Thanks Dylan Flaute!
+* Token batch at inference
+* Small fixes and add-ons
+
 
 ## [0.9.1](https://github.com/OpenNMT/OpenNMT-py/tree/0.9.1) (2019-06-13)
 * New mechanism for MultiGPU training "1 batch producer / multi batch consumers"

@@ -1,2 +0,0 @@
-FROM pytorch/pytorch:latest
-RUN git clone https://github.com/OpenNMT/OpenNMT-py.git && cd OpenNMT-py && pip install -r requirements.txt && python setup.py install

@@ -3,7 +3,7 @@
 [![Build Status](https://travis-ci.org/OpenNMT/OpenNMT-py.svg?branch=master)](https://travis-ci.org/OpenNMT/OpenNMT-py)
 [![Run on FH](https://img.shields.io/badge/Run%20on-FloydHub-blue.svg)](https://floydhub.com/run?template=https://github.com/OpenNMT/OpenNMT-py)
 
-This is a [Pytorch](https://github.com/pytorch/pytorch)
+This is a [PyTorch](https://github.com/pytorch/pytorch)
 port of [OpenNMT](https://github.com/OpenNMT/OpenNMT),
 an open-source (MIT) neural machine translation system. It is designed to be research friendly to try out new ideas in translation, summary, image-to-text, morphology, and many other domains. Some companies have proven the code to be production ready.
 
@@ -28,32 +28,43 @@ Table of Contents
 
 ## Requirements
 
-All dependencies can be installed via:
+Install `OpenNMT-py` from `pip`:
 
 ```bash
-pip install -r requirements.txt
+pip install OpenNMT-py
 ```
 
-NOTE: If you have MemoryError in the install try to use:
+or from the sources:
 
 ```bash
-pip install -r requirements.txt --no-cache-dir
+git clone https://github.com/OpenNMT/OpenNMT-py.git
+cd OpenNMT-py
+python setup.py install
 ```
-Note that we currently only support PyTorch 1.1 (should work with 1.0)
+
+Note: If you have MemoryError in the install try to use `pip` with `--no-cache-dir`.
+
+*(Optional)* some advanced features (e.g. working audio, image or pretrained models) require extra packages; you can install them with:
+```bash
+pip install -r requirements.opt.txt
+```
+
+Note:
+
+- some features require Python 3.5 and after (eg: Distributed multigpu, entmax)
+- we currently only support PyTorch 1.4
+
 ## Features
 
-- [data preprocessing](http://opennmt.net/OpenNMT-py/options/preprocess.html)
-- [Inference (translation) with batching and beam search](http://opennmt.net/OpenNMT-py/options/translate.html)
-- [Multiple source and target RNN (lstm/gru) types and attention (dotprod/mlp) types](http://opennmt.net/OpenNMT-py/options/train.html#model-encoder-decoder)
-- [TensorBoard](http://opennmt.net/OpenNMT-py/options/train.html#logging)
-- [Source word features](http://opennmt.net/OpenNMT-py/options/train.html#model-embeddings)
-- [Pretrained Embeddings](http://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-pretrained-embeddings-e-g-glove)
-- ["Attention is all you need"](http://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-the-transformer-model)
-- [Multi-GPU](http://opennmt.net/OpenNMT-py/FAQ.html##do-you-support-multi-gpu)
+- [Seq2Seq models (encoder-decoder) with multiple RNN cells (lstm/gru) and attention (dotprod/mlp) types](http://opennmt.net/OpenNMT-py/options/train.html#model-encoder-decoder)
+- [Transformer models](http://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-the-transformer-model)
 - [Copy and Coverage Attention](http://opennmt.net/OpenNMT-py/options/train.html#model-attention)
+- [Pretrained Embeddings](http://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-pretrained-embeddings-e-g-glove)
+- [Source word features](http://opennmt.net/OpenNMT-py/options/train.html#model-embeddings)
 - [Image-to-text processing](http://opennmt.net/OpenNMT-py/im2text.html)
 - [Speech-to-text processing](http://opennmt.net/OpenNMT-py/speech2text.html)
+- [TensorBoard logging](http://opennmt.net/OpenNMT-py/options/train.html#logging)
+- [Multi-GPU training](http://opennmt.net/OpenNMT-py/FAQ.html##do-you-support-multi-gpu)
+- [Data preprocessing](http://opennmt.net/OpenNMT-py/options/preprocess.html)
+- [Inference (translation) with batching and beam search](http://opennmt.net/OpenNMT-py/options/translate.html)
 - Inference time loss functions.
 - [Conv2Conv convolution model]
 - SRU "RNNs faster than CNN" paper

@@ -67,7 +78,7 @@ Note that we currently only support PyTorch 1.1 (should work with 1.0)
 ### Step 1: Preprocess the data
 
 ```bash
-python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo
+onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo
 ```
 
 We will be working with some example data in `data/` folder.

@@ -94,21 +105,21 @@ Internally the system never touches the words themselves, but uses these indices
 ### Step 2: Train the model
 
 ```bash
-python train.py -data data/demo -save_model demo-model
+onmt_train -data data/demo -save_model demo-model
 ```
 
 The main train command is quite simple. Minimally it takes a data file
 and a save file. This will run the default model, which consists of a
 2-layer LSTM with 500 hidden units on both the encoder/decoder.
 If you want to train on GPU, you need to set, as an example:
-CUDA_VISIBLE_DEVICES=1,3
+`CUDA_VISIBLE_DEVICES=1,3`
 `-world_size 2 -gpu_ranks 0 1` to use (say) GPU 1 and 3 on this node only.
 To know more about distributed training on single or multi nodes, read the FAQ section.
 
 ### Step 3: Translate
 
 ```bash
-python translate.py -model demo-model_acc_XX.XX_ppl_XXX.XX_eX.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
+onmt_translate -model demo-model_acc_XX.XX_ppl_XXX.XX_eX.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
 ```
 
 Now you have a model which you can use to predict on new data. We do this by running beam search. This will output predictions into `pred.txt`.

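Putting the GPU note from Step 2 together into one command line, a minimal sketch (the device ids and the `data/demo`/`demo-model` names are simply the examples used above):

```bash
# Train the demo model on GPUs 1 and 3 of this machine with two ranks
CUDA_VISIBLE_DEVICES=1,3 onmt_train -data data/demo -save_model demo-model \
    -world_size 2 -gpu_ranks 0 1
```
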
@@ -151,7 +162,7 @@ Major contributors are:
 [Dylan Flaute](http://github.com/flauted) (University of Dayton)
 and more!
 
-OpentNMT-py belongs to the OpenNMT project along with OpenNMT-Lua and OpenNMT-tf.
+OpenNMT-py belongs to the OpenNMT project along with OpenNMT-Lua and OpenNMT-tf.
 
 ## Citation
 
@@ -1,13 +0,0 @@
-{
-"models_root": "./available_models",
-"models":[{
-"model": "onmt-hien.pt",
-"timeout": -1,
-"on_timeout": "unload",
-"model_root": "../model/",
-"opt": {
-"batch_size": 1,
-"beam_size": 10
-}
-}]
-}

@@ -4,3 +4,4 @@ sphinxcontrib.mermaid
 sphinx-rtd-theme
 recommonmark
 sphinx-argparse
+sphinx_markdown_tables

@@ -8,7 +8,7 @@ the script is a slightly modified version of ylhsieh's one2.
 
 Usage:
 
-```
+```shell
 embeddings_to_torch.py [-h] [-emb_file_both EMB_FILE_BOTH]
 [-emb_file_enc EMB_FILE_ENC]
 [-emb_file_dec EMB_FILE_DEC] -output_file
@@ -16,23 +16,23 @@ embeddings_to_torch.py [-h] [-emb_file_both EMB_FILE_BOTH]
 [-skip_lines SKIP_LINES]
 [-type {GloVe,word2vec}]
 ```
 
 Run embeddings_to_torch.py -h for more complete usage info.
 
-Example
+### Example
 
+1. Get GloVe files:
 
-1) get GloVe files:
-
-```
+```shell
 mkdir "glove_dir"
 wget http://nlp.stanford.edu/data/glove.6B.zip
 unzip glove.6B.zip -d "glove_dir"
 ```
 
-2) prepare data:
+2. Prepare data:
 
-```
-python preprocess.py \
+```shell
+onmt_preprocess \
 -train_src data/train.src.txt \
 -train_tgt data/train.tgt.txt \
 -valid_src data/valid.src.txt \
@@ -40,18 +40,18 @@ python preprocess.py \
 -save_data data/data
 ```
 
-3) prepare embeddings:
+3. Prepare embeddings:
 
-```
+```shell
 ./tools/embeddings_to_torch.py -emb_file_both "glove_dir/glove.6B.100d.txt" \
 -dict_file "data/data.vocab.pt" \
 -output_file "data/embeddings"
 ```
 
-4) train using pre-trained embeddings:
+4. Train using pre-trained embeddings:
 
-```
-python train.py -save_model data/model \
+```shell
+onmt_train -save_model data/model \
 -batch_size 64 \
 -layers 2 \
 -rnn_size 200 \
@@ -61,14 +61,13 @@ python train.py -save_model data/model \
 -data data/data
 ```
 
-## How do I use the Transformer model? Do you support multi-gpu?
+## How do I use the Transformer model?
 
 The transformer model is very sensitive to hyperparameters. To run it
 effectively you need to set a bunch of different options that mimic the Google
 setup. We have confirmed the following command can replicate their WMT results.
 
-```
+```shell
 python train.py -data /tmp/de2/data -save_model /tmp/extra \
 -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 \
 -encoder_type transformer -decoder_type transformer -position_encoding \
@@ -86,24 +85,34 @@ Here are what each of the parameters mean:
 * `position_encoding`: add sinusoidal position encoding to each embedding
 * `optim adam`, `decay_method noam`, `warmup_steps 8000`: use special learning rate.
 * `batch_type tokens`, `normalization tokens`, `accum_count 4`: batch and normalize based on number of tokens and not sentences. Compute gradients based on four batches.
-- `label_smoothing 0.1`: use label smoothing loss.
+* `label_smoothing 0.1`: use label smoothing loss.
 
-Multi GPU settings
-First you need to make sure you export CUDA_VISIBLE_DEVICES=0,1,2,3
-If you want to use GPU id 1 and 3 of your OS, you will need to export CUDA_VISIBLE_DEVICES=1,3
-* `world_size 4 gpu_ranks 0 1 2 3`: This will use 4 GPU on this node only.
-
-If you want to use 2 nodes with 2 GPU each, you need to set -master_ip and master_port, and
-* `world_size 4 gpu_ranks 0 1`: on the first node
-* `world_size 4 gpu_ranks 2 3`: on the second node
-* `accum_count 2`: This will accumulate over 2 batches before updating parameters.
-
-if you use a regular network card (1 Gbps) then we suggest to use a higher accum_count to minimize the inter-node communication.
+## Do you support multi-gpu?
+
+First you need to make sure you `export CUDA_VISIBLE_DEVICES=0,1,2,3`.
+
+If you want to use GPU id 1 and 3 of your OS, you will need to `export CUDA_VISIBLE_DEVICES=1,3`
+
+Both `-world_size` and `-gpu_ranks` need to be set. E.g. `-world_size 4 -gpu_ranks 0 1 2 3` will use 4 GPU on this node only.
+
+If you want to use 2 nodes with 2 GPU each, you need to set `-master_ip` and `-master_port`, and
+
+* `-world_size 4 -gpu_ranks 0 1`: on the first node
+* `-world_size 4 -gpu_ranks 2 3`: on the second node
+* `-accum_count 2`: This will accumulate over 2 batches before updating parameters.
+
+if you use a regular network card (1 Gbps) then we suggest to use a higher `-accum_count` to minimize the inter-node communication.
+
+**Note:**
+
+When training on several GPUs, you can't have them in 'Exclusive' compute mode (`nvidia-smi -c 3`).
+
+The multi-gpu setup relies on a Producer/Consumer setup. This setup means there will be `2<n_gpu> + 1` processes spawned, with 2 processes per GPU, one for model training and one (Consumer) that hosts a `Queue` of batches that will be processed next. The additional process is the Producer, creating batches and sending them to the Consumers. This setup is beneficial for both wall time and memory, since it loads data shards 'in advance', and does not require to load it for each GPU process.
 
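As an illustration of the flags described in the multi-gpu answer above, a minimal single-node sketch (the `data/demo` and `demo-model` names are assumptions borrowed from the quickstart, not part of the FAQ text):

```bash
# Single node, 4 GPUs: expose the devices, then launch one rank per GPU
export CUDA_VISIBLE_DEVICES=0,1,2,3
python train.py -data data/demo -save_model demo-model \
    -world_size 4 -gpu_ranks 0 1 2 3 -accum_count 2
```
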
 ## How can I ensemble Models at inference?
 
 You can specify several models in the translate.py command line: -model model1_seed1 model2_seed2
-Bear in mind that your models must share the same traget vocabulary.
+Bear in mind that your models must share the same target vocabulary.
 
 ## How can I weight different corpora at training?
 
@@ -112,7 +121,8 @@ Bear in mind that your models must share the same traget vocabulary.
 We introduced `-train_ids` which is a list of IDs that will be given to the preprocessed shards.
 
 E.g. we have two corpora : `parallel.en` and `parallel.de` + `from_backtranslation.en` `from_backtranslation.de`, we can pass the following in the `preprocess.py` command:
-```
+
+```shell
 ...
 -train_src parallel.en from_backtranslation.en \
 -train_tgt parallel.de from_backtranslation.de \
@@ -120,19 +130,55 @@ E.g. we have two corpora : `parallel.en` and `parallel.de` + `from_backtranslat
 -save_data my_data \
 ...
 ```
 
 and it will dump `my_data.train_A.X.pt` based on `parallel.en`//`parallel.de` and `my_data.train_B.X.pt` based on `from_backtranslation.en`//`from_backtranslation.de`.
 
 ### Training
 
 We introduced `-data_ids` based on the same principle as above, as well as `-data_weights`, which is the list of the weight each corpus should have.
 E.g.
-```
+
+```shell
 ...
 -data my_data \
 -data_ids A B \
 -data_weights 1 7 \
 ...
 ```
 
 will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`, and that when building batches, we'll take 1 example from corpus A, then 7 examples from corpus B, and so on.
 
 **Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing.
 
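Spelled out, the two elided commands above might look like the following sketch. The `...` placeholders in the FAQ stand for the usual preprocess/train options; every path and option here other than `-train_ids`, `-data_ids` and `-data_weights` is an assumption for illustration only:

```bash
# Preprocess the two corpora into shards tagged A (parallel) and B (back-translated)
python preprocess.py \
    -train_src parallel.en from_backtranslation.en \
    -train_tgt parallel.de from_backtranslation.de \
    -train_ids A B \
    -valid_src valid.en -valid_tgt valid.de \
    -save_data my_data

# Train, drawing 1 example from corpus A for every 7 examples from corpus B
python train.py \
    -data my_data \
    -data_ids A B \
    -data_weights 1 7 \
    -save_model my_model
```
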
+## Can I get word alignment while translating?
+
+### Raw alignments from averaging Transformer attention heads
+
+Currently, we support producing word alignment while translating for Transformer based models. Using `-report_align` when calling `translate.py` will output the inferred alignments in Pharaoh format. Those alignments are computed from an argmax on the average of the attention heads of the *second to last* decoder layer. The resulting alignment src-tgt (Pharaoh) will be pasted to the translation sentence, separated by ` ||| `.
+Note: The *second to last* default behaviour was empirically determined. It is not the same as the paper (they take the *penultimate* layer), probably because of light differences in the architecture.
+
+* alignments use the standard "Pharaoh format", where a pair `i-j` indicates the i<sub>th</sub> word of source language is aligned to j<sub>th</sub> word of target language.
+* Example: {'src': 'das stimmt nicht !'; 'output': 'that is not true ! ||| 0-0 0-1 1-2 2-3 1-4 1-5 3-6'}
+* Using the `-tgt` option when calling `translate.py`, we output alignments between the source and the gold target rather than the inferred target, assuming we're doing evaluation.
+* To convert subword alignments to word alignments, or symmetrize bidirectional alignments, please refer to the [lilt scripts](https://github.com/lilt/alignment-scripts).
+
+### Supervised learning on a specific head
+
+The quality of output alignments can be further improved by providing reference alignments while training. This will invoke multi-task learning on translation and alignment. This is an implementation based on the paper [Jointly Learning to Align and Translate with Transformer Models](https://arxiv.org/abs/1909.02074).
+
+The data needs to be preprocessed with the reference alignments in order to learn the supervised task.
+
+When calling `preprocess.py`, add:
+
+* `--train_align <path>`: path(s) to the training alignments in Pharaoh format
+* `--valid_align <path>`: path to the validation set alignments in Pharaoh format (optional).
+
+The reference alignment file(s) could be generated by [GIZA++](https://github.com/moses-smt/mgiza/) or [fast_align](https://github.com/clab/fast_align).
+
+Note: There should be no blank lines in the alignment files provided.
+
+Options to learn such alignments are:
+
+* `-lambda_align`: set the value > 0.0 to enable joint align training, the paper suggests 0.05;
+* `-alignment_layer`: indicate the index of the decoder layer;
+* `-alignment_heads`: number of alignment heads for the alignment task - should be set to 1 for the supervised task, and preferably kept to default (or same as `num_heads`) for the average task;
+* `-full_context_alignment`: do full context decoder pass (no future mask) when computing alignments. This will slow down the training (~12% in terms of tok/s) but will be beneficial to generate better alignment.

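A minimal sketch of the two alignment workflows added above. The flags are the ones named in the FAQ text; the model, data and alignment file paths are placeholders for illustration:

```bash
# Inference-time alignments: append Pharaoh-format alignments to each translation
python translate.py -model demo-model.pt -src data/src-test.txt \
    -output pred.txt -report_align

# Supervised alignments: preprocess with reference alignments, then enable the joint loss
python preprocess.py -train_src src-train.txt -train_tgt tgt-train.txt \
    --train_align train.align \
    -valid_src src-val.txt -valid_tgt tgt-val.txt --valid_align val.align \
    -save_data data/demo
python train.py -data data/demo -save_model demo-model \
    -encoder_type transformer -decoder_type transformer \
    -lambda_align 0.05 -alignment_heads 1 -full_context_alignment
```
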
@@ -3,7 +3,7 @@
 For this example, we will assume that we have run preprocess to
 create our datasets. For instance
 
-> python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 10000 -tgt_vocab_size 10000
+> onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 10000 -tgt_vocab_size 10000
 
@@ -35,8 +35,8 @@ For CNN-DM we follow See et al. [2] and additionally truncate the source length
 
 (1) CNN-DM
 
-```
-python preprocess.py -train_src data/cnndm/train.txt.src \
+```bash
+onmt_preprocess -train_src data/cnndm/train.txt.src \
 -train_tgt data/cnndm/train.txt.tgt.tagged \
 -valid_src data/cnndm/val.txt.src \
 -valid_tgt data/cnndm/val.txt.tgt.tagged \

@@ -52,8 +52,8 @@ python preprocess.py -train_src data/cnndm/train.txt.src \
 
 (2) Gigaword
 
-```
-python preprocess.py -train_src data/giga/train.article.txt \
+```bash
+onmt_preprocess -train_src data/giga/train.article.txt \
 -train_tgt data/giga/train.title.txt \
 -valid_src data/giga/valid.article.txt \
 -valid_tgt data/giga/valid.title.txt \

@@ -87,8 +87,8 @@ We additionally set the maximum norm of the gradient to 2, and renormalize if th
 
 (1) CNN-DM
 
-```
-python train.py -save_model models/cnndm \
+```bash
+onmt_train -save_model models/cnndm \
 -data data/cnndm/CNNDM \
 -copy_attn \
 -global_attention mlp \

@@ -116,8 +116,8 @@ python train.py -save_model models/cnndm \
 
 The following script trains the transformer model on CNN-DM
 
-```
-python -u train.py -data data/cnndm/CNNDM \
+```bash
+onmt_train -data data/cnndm/CNNDM \
 -save_model models/cnndm \
 -layers 4 \
 -rnn_size 512 \

@@ -152,7 +152,7 @@ python -u train.py -data data/cnndm/CNNDM \
 
 Gigaword can be trained equivalently. As a baseline, we show a model trained with the following command:
 
 ```
-python train.py -data data/giga/GIGA \
+onmt_train -data data/giga/GIGA \
 -save_model models/giga \
 -copy_attn \
 -reuse_copy_attn \

@@ -177,7 +177,7 @@ During inference, we use beam-search with a beam-size of 10. We also added speci
 
 (1) CNN-DM
 
 ```
-python translate.py -gpu X \
+onmt_translate -gpu X \
 -batch_size 20 \
 -beam_size 10 \
 -model models/cnndm... \

@@ -221,8 +221,6 @@ For evaluation of large test sets such as Gigaword, we use the a parallel python
 
 ### Scores and Models
 
-The website generator has trouble rendering tables, if you can't read the results, please go [here](https://github.com/OpenNMT/OpenNMT-py/blob/master/docs/source/Summarization.md) for correct format.
-
 #### CNN-DM
 
 | Model Type | Model | R1 R | R1 P | R1 F | R2 R | R2 P | R2 F | RL R | RL P | RL F |

@@ -231,7 +229,7 @@ The website generator has trouble rendering tables, if you can't read the result
 | Pointer-Generator [2] | [link](https://github.com/abisee/pointer-generator) | 37.76 | 37.60| 36.44| 16.31| 16.12| 15.66| 34.66| 34.46| 33.42 |
 | OpenNMT BRNN (1 layer, emb 128, hid 512) | [link](https://s3.amazonaws.com/opennmt-models/Summary/ada6_bridge_oldcopy_tagged_acc_54.17_ppl_11.17_e20.pt) | 40.90| 40.20| 39.02| 17.91| 17.99| 17.25| 37.76 | 37.18| 36.05 |
 | OpenNMT BRNN (1 layer, emb 128, hid 512, shared embeddings) | [link](https://s3.amazonaws.com/opennmt-models/Summary/ada6_bridge_oldcopy_tagged_share_acc_54.50_ppl_10.89_e20.pt) | 38.59 | 40.60 | 37.97 | 16.75 | 17.93 | 16.59 | 35.67 | 37.60 | 35.13 |
-| OpenNMT BRNN (2 layer, emb 256, hid 1024) | [link](https://s3.amazonaws.com/opennmt-models/Summary/ada6_bridge_oldcopy_tagged_larger_acc_54.84_ppl_10.58_e17.ptt) | 40.41 | 40.94 | 39.12 | 17.76 | 18.38 | 17.35 | 37.27 | 37.83 | 36.12 |
+| OpenNMT BRNN (2 layer, emb 256, hid 1024) | [link](https://s3.amazonaws.com/opennmt-models/Summary/ada6_bridge_oldcopy_tagged_larger_acc_54.84_ppl_10.58_e17.pt) | 40.41 | 40.94 | 39.12 | 17.76 | 18.38 | 17.35 | 37.27 | 37.83 | 36.12 |
 | OpenNMT Transformer | [link](https://s3.amazonaws.com/opennmt-models/sum_transformer_model_acc_57.25_ppl_9.22_e16.pt) | 40.31 | 41.09 | 39.25 | 17.97 | 18.46 | 17.54 | 37.41 | 38.18 | 36.45 |

@@ -49,7 +49,8 @@ extensions = ['sphinx.ext.autodoc',
 'sphinx.ext.napoleon',
 'sphinxcontrib.mermaid',
 'sphinxcontrib.bibtex',
-'sphinxarg.ext']
+'sphinxarg.ext',
+'sphinx_markdown_tables']
 
 # Show base classes
 autodoc_default_options = {

@@ -17,19 +17,19 @@ Step 1. Preprocess the data.
 ```bash
 for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; done; done
 for l in en de; do for f in data/multi30k/*.$l; do perl tools/tokenizer.perl -a -no-escape -l $l -q < $f > $f.atok; done; done
-python preprocess.py -train_src data/multi30k/train.en.atok -train_tgt data/multi30k/train.de.atok -valid_src data/multi30k/val.en.atok -valid_tgt data/multi30k/val.de.atok -save_data data/multi30k.atok.low -lower
+onmt_preprocess -train_src data/multi30k/train.en.atok -train_tgt data/multi30k/train.de.atok -valid_src data/multi30k/val.en.atok -valid_tgt data/multi30k/val.de.atok -save_data data/multi30k.atok.low -lower
 ```
 
 Step 2. Train the model.
 
 ```bash
-python train.py -data data/multi30k.atok.low -save_model multi30k_model -gpu_ranks 0
+onmt_train -data data/multi30k.atok.low -save_model multi30k_model -gpu_ranks 0
 ```
 
 Step 3. Translate sentences.
 
 ```bash
-python translate.py -gpu 0 -model multi30k_model_*_e13.pt -src data/multi30k/test2016.en.atok -tgt data/multi30k/test2016.de.atok -replace_unk -verbose -output multi30k.test.pred.atok
+onmt_translate -gpu 0 -model multi30k_model_*_e13.pt -src data/multi30k/test2016.en.atok -tgt data/multi30k/test2016.de.atok -replace_unk -verbose -output multi30k.test.pred.atok
 ```
 
 And evaluate
 
@@ -0,0 +1,94 @@
+# Gated Graph Sequence Neural Networks
+
+Graph-to-sequence networks allow information representable as a graph (such as an annotated NLP sentence or computer code structure as an AST) to be connected to a sequence generator to produce output which can benefit from the graph structure of the input.
+
+The training option `-encoder_type ggnn` implements a GGNN (Gated Graph Neural Network) based on github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.
+
+The ggnn encoder is used for program equivalence proof generation in the paper <a href="https://arxiv.org/abs/2002.06799">Equivalence of Dataflow Graphs via Rewrite Rules Using a Graph-to-Sequence Neural Model</a>. That paper shows the benefit of the graph-to-sequence model over a sequence-to-sequence model for this problem which can be well represented with graphical input. The integration of the ggnn network into the <a href="https://github.com/OpenNMT/OpenNMT-py/">OpenNMT-py</a> system supports attention on the nodes as well as a copy mechanism.
+
+### Dependencies
+
+* There are no additional dependencies beyond the rnn-to-rnn sequence-to-sequence requirements.
+
+### Quick Start
+
+To get started, we provide a toy graph-to-sequence example. We assume that the working directory is `OpenNMT-py` throughout this document.
+
+0) Download the data to a sibling directory.
+
+```
+cd ..
+git clone https://github.com/SteveKommrusch/OpenNMT-py-ggnn-example
+source OpenNMT-py-ggnn-example/env.sh
+cd OpenNMT-py
+```
+
+1) Preprocess the data.
+
+```
+python preprocess.py -train_src $data_path/src-train.txt -train_tgt $data_path/tgt-train.txt -valid_src $data_path/src-val.txt -valid_tgt $data_path/tgt-val.txt -src_seq_length 1000 -tgt_seq_length 30 -src_vocab $data_path/srcvocab.txt -tgt_vocab $data_path/tgtvocab.txt -dynamic_dict -save_data $data_path/final 2>&1 > $data_path/preprocess.out
+```
+
+2) Train the model.
+
+```
+python train.py -data $data_path/final -encoder_type ggnn -layers 2 -decoder_type rnn -rnn_size 256 -learning_rate 0.1 -start_decay_steps 5000 -learning_rate_decay 0.8 -global_attention general -batch_size 32 -word_vec_size 256 -bridge -train_steps 10000 -gpu_ranks 0 -save_checkpoint_steps 5000 -save_model $data_path/final-model -src_vocab $data_path/srcvocab.txt -n_edge_types 9 -state_dim 256 -n_steps 10 -n_node 64 > $data_path/train.final.out
+```
+
+3) Translate the graph of 2 equivalent linear algebra expressions into the axiom list which proves them equivalent.
+
+```
+python translate.py -model $data_path/final-model_step_10000.pt -src $data_path/src-test.txt -beam_size 5 -n_best 5 -gpu 0 -output $data_path/pred-test_beam5.txt -dynamic_dict 2>&1 > $data_path/translate5.out
+```
+
+### Graph data format
+
+The GGNN implementation leverages the sequence processing and vocabulary
+interface of OpenNMT. Each graph is provided on an input line, much like
+a sentence is provided on an input line. A graph neural network input line
+includes `sentence tokens`, `feature values`, and `edges` separated by
+`<EOT>` (end of tokens) tokens. Below is an example of the input for a pair
+of algebraic equations structured as a graph:
+
+```
+Sentence tokens        Feature values             Edges
+---------------        ------------------         -------------------------------------------------------
+- - - 0 a a b b <EOT>  0 1 2 3 4 4 2 3 12 <EOT>   0 2 1 3 2 4 , 0 6 1 7 2 5 , 0 4 , 0 5 , , , , 8 0 , 8 1
+```
+
+The equations being represented are `((a - a) - b)` and `(0 - b)`, the
+`sentence tokens` of which are provided before the first `<EOT>`. After
+the first `<EOT>`, the `feature values` are provided. These are extra
+flags with information on each node in the graph. In this case, the 8
+sentence tokens have feature flags ranging from 0 to 4; the 9th feature
+flag defines a 9th node in the graph which does not have sentence token
+information, just feature data. Nodes with any non-number flag (such as
+`-` or `.`) will not have a feature added. Multiple groups of features
+can be provided by using the `,` delimiter between the first and second
+`<EOT>` tokens. After the second `<EOT>` token, edge information is provided.
+Edge data is given as node pairs, hence `<EOT> 0 2 1 3` indicates that there
+are edges from node 0 to node 2 and from node 1 to node 3. The GGNN supports
+multiple edge types (which result mathematically in multiple weight matrices
+for the model) and the edge types are separated by `,` tokens after the
+second `<EOT>` token.
+
+Note that the source vocabulary file needs to include the `<EOT>` token,
+the `,` token, and all of the numbers used for feature flags and node
+identifiers in the edge list.
+
+### Options
+
+* `-rnn_type (str)`: style of recurrent unit to use, one of [LSTM]
+* `-state_dim (int)`: Number of state dimensions in nodes
+* `-n_edge_types (int)`: Number of edge types
+* `-bidir_edges (bool)`: True if reverse edges should be automatically created
+* `-n_node (int)`: Max nodes in graph
+* `-bridge_extra_node (bool)`: True indicates only the vector from the 1st extra node (after token listing) should be used for decoder initialization; False indicates all node vectors should be averaged together for decoder initialization
+* `-n_steps (int)`: Steps to advance graph encoder for stabilization
+* `-src_vocab (int)`: Path to source vocabulary
+
+### Acknowledgement
+
+This gated graph neural network is leveraged from github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.

@@ -27,31 +27,51 @@ Im2Text consists of four commands:
 
 0) Download the data.
 
-```
+```bash
 wget -O data/im2text.tgz http://lstm.seas.harvard.edu/latex/im2text_small.tgz; tar zxf data/im2text.tgz -C data/
 ```
 
 1) Preprocess the data.
 
-```
-python preprocess.py -data_type img -src_dir data/im2text/images/ -train_src data/im2text/src-train.txt \
+```bash
+onmt_preprocess -data_type img \
+-src_dir data/im2text/images/ \
+-train_src data/im2text/src-train.txt \
 -train_tgt data/im2text/tgt-train.txt -valid_src data/im2text/src-val.txt \
 -valid_tgt data/im2text/tgt-val.txt -save_data data/im2text/demo \
--tgt_seq_length 150 -tgt_words_min_frequency 2 -shard_size 500 -image_channel_size 1
+-tgt_seq_length 150 \
+-tgt_words_min_frequency 2 \
+-shard_size 500 \
+-image_channel_size 1
 ```
 
 2) Train the model.
 
-```
-python train.py -model_type img -data data/im2text/demo -save_model demo-model -gpu_ranks 0 -batch_size 20 \
--max_grad_norm 20 -learning_rate 0.1 -word_vec_size 80 -encoder_type brnn -image_channel_size 1
+```bash
+onmt_train -model_type img \
+-data data/im2text/demo \
+-save_model demo-model \
+-gpu_ranks 0 \
+-batch_size 20 \
+-max_grad_norm 20 \
+-learning_rate 0.1 \
+-word_vec_size 80 \
+-encoder_type brnn \
+-image_channel_size 1
 ```
 
 3) Translate the images.
 
-```
-python translate.py -data_type img -model demo-model_acc_x_ppl_x_e13.pt -src_dir data/im2text/images \
--src data/im2text/src-test.txt -output pred.txt -max_length 150 -beam_size 5 -gpu 0 -verbose
+```bash
+onmt_translate -data_type img \
+-model demo-model_acc_x_ppl_x_e13.pt \
+-src_dir data/im2text/images \
+-src data/im2text/src-test.txt \
+-output pred.txt \
+-max_length 150 \
+-beam_size 5 \
+-gpu 0 \
+-verbose
 ```
 
 The above dataset is sampled from the [im2latex-100k-dataset](http://lstm.seas.harvard.edu/latex/im2text.tgz). We provide a trained model [[link]](http://lstm.seas.harvard.edu/latex/py-model.pt) on this dataset.

@ -6,20 +6,22 @@ This portal provides a detailed documentation of the OpenNMT toolkit. It describ
|
||||||
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
Install from `pip`:
|
||||||
1\. [Install PyTorch](http://pytorch.org/)
|
Install `OpenNMT-py` from `pip`:
|
||||||
|
|
||||||
2\. Clone the OpenNMT-py repository:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/OpenNMT/OpenNMT-py
|
pip install OpenNMT-py
|
||||||
cd OpenNMT-py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
3\. Install required libraries
|
or from the sources:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
git clone https://github.com/OpenNMT/OpenNMT-py.git
|
||||||
|
cd OpenNMT-py
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
*(Optional)* some advanced features (e.g. working with audio, images or pretrained models) require extra packages; you can install them with:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.opt.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
And you are ready to go! Take a look at the [quickstart](quickstart) to familiarize yourself with the main training workflow.
|
And you are ready to go! Take a look at the [quickstart](quickstart) to familiarize yourself with the main training workflow.
|
||||||
|
|
|
@ -25,9 +25,9 @@ Decoding Strategies
|
||||||
.. autoclass:: onmt.translate.BeamSearch
|
.. autoclass:: onmt.translate.BeamSearch
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
.. autofunction:: onmt.translate.random_sampling.sample_with_temperature
|
.. autofunction:: onmt.translate.greedy_search.sample_with_temperature
|
||||||
|
|
||||||
.. autoclass:: onmt.translate.RandomSampling
|
.. autoclass:: onmt.translate.GreedySearch
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
Scoring
|
Scoring
|
||||||
|
|
|
@ -2,6 +2,6 @@ Preprocess
|
||||||
==========
|
==========
|
||||||
|
|
||||||
.. argparse::
|
.. argparse::
|
||||||
:filename: ../preprocess.py
|
:filename: ../onmt/bin/preprocess.py
|
||||||
:func: _get_parser
|
:func: _get_parser
|
||||||
:prog: preprocess.py
|
:prog: preprocess.py
|
|
@ -2,6 +2,6 @@ Server
|
||||||
=========
|
=========
|
||||||
|
|
||||||
.. argparse::
|
.. argparse::
|
||||||
:filename: ../server.py
|
:filename: ../onmt/bin/server.py
|
||||||
:func: _get_parser
|
:func: _get_parser
|
||||||
:prog: server.py
|
:prog: server.py
|
|
@ -2,6 +2,6 @@ Train
|
||||||
=====
|
=====
|
||||||
|
|
||||||
.. argparse::
|
.. argparse::
|
||||||
:filename: ../train.py
|
:filename: ../onmt/bin/train.py
|
||||||
:func: _get_parser
|
:func: _get_parser
|
||||||
:prog: train.py
|
:prog: train.py
|
|
@ -2,6 +2,6 @@ Translate
|
||||||
=========
|
=========
|
||||||
|
|
||||||
.. argparse::
|
.. argparse::
|
||||||
:filename: ../translate.py
|
:filename: ../onmt/bin/translate.py
|
||||||
:func: _get_parser
|
:func: _get_parser
|
||||||
:prog: translate.py
|
:prog: translate.py
|
|
@ -6,7 +6,7 @@
|
||||||
### Step 1: Preprocess the data
|
### Step 1: Preprocess the data
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo
|
onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo
|
||||||
```
|
```
|
||||||
|
|
||||||
We will be working with some example data in `data/` folder.
|
We will be working with some example data in `data/` folder.
|
||||||
|
@ -30,7 +30,7 @@ Federal Master Trainer and Senior Instructor of the Italian Federation of Aerobi
|
||||||
### Step 2: Train the model
|
### Step 2: Train the model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python train.py -data data/demo -save_model demo-model
|
onmt_train -data data/demo -save_model demo-model
|
||||||
```
|
```
|
||||||
|
|
||||||
The main train command is quite simple. Minimally it takes a data file
|
The main train command is quite simple. Minimally it takes a data file
|
||||||
|
@ -44,7 +44,7 @@ To know more about distributed training on single or multi nodes, read the FAQ s
|
||||||
### Step 3: Translate
|
### Step 3: Translate
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python translate.py -model demo-model_XYZ.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
|
onmt_translate -model demo-model_XYZ.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
|
||||||
```
|
```
|
||||||
|
|
||||||
Now you have a model which you can use to predict on new data. We do this by running beam search. This will output predictions into `pred.txt`.
|
Now you have a model which you can use to predict on new data. We do this by running beam search. This will output predictions into `pred.txt`.
|
||||||
|
|
|
@ -20,6 +20,22 @@
|
||||||
year={2016}
|
year={2016}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@inproceedings{Li2016,
|
||||||
|
author = {Yujia Li and
|
||||||
|
Daniel Tarlow and
|
||||||
|
Marc Brockschmidt and
|
||||||
|
Richard S. Zemel},
|
||||||
|
title = {Gated Graph Sequence Neural Networks},
|
||||||
|
booktitle = {4th International Conference on Learning Representations, {ICLR} 2016,
|
||||||
|
San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings},
|
||||||
|
year = {2016},
|
||||||
|
crossref = {DBLP:conf/iclr/2016},
|
||||||
|
url = {http://arxiv.org/abs/1511.05493},
|
||||||
|
timestamp = {Thu, 25 Jul 2019 14:25:40 +0200},
|
||||||
|
biburl = {https://dblp.org/rec/journals/corr/LiTBZ15.bib},
|
||||||
|
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||||
|
}
|
||||||
|
|
||||||
@inproceedings{Bahdanau2015,
|
@inproceedings{Bahdanau2015,
|
||||||
archivePrefix = {arXiv},
|
archivePrefix = {arXiv},
|
||||||
arxivId = {1409.0473},
|
arxivId = {1409.0473},
|
||||||
|
@ -435,3 +451,33 @@ year = {2015}
|
||||||
biburl = {https://dblp.org/rec/bib/journals/corr/MartinsA16},
|
biburl = {https://dblp.org/rec/bib/journals/corr/MartinsA16},
|
||||||
bibsource = {dblp computer science bibliography, https://dblp.org}
|
bibsource = {dblp computer science bibliography, https://dblp.org}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@inproceedings{garg2019jointly,
|
||||||
|
title = {Jointly Learning to Align and Translate with Transformer Models},
|
||||||
|
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
|
||||||
|
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
|
||||||
|
address = {Hong Kong},
|
||||||
|
month = {November},
|
||||||
|
url = {https://arxiv.org/abs/1909.02074},
|
||||||
|
year = {2019},
|
||||||
|
}
|
||||||
|
|
||||||
|
@inproceedings{DeeperTransformer,
|
||||||
|
title = "Learning Deep Transformer Models for Machine Translation",
|
||||||
|
author = "Wang, Qiang and
|
||||||
|
Li, Bei and
|
||||||
|
Xiao, Tong and
|
||||||
|
Zhu, Jingbo and
|
||||||
|
Li, Changliang and
|
||||||
|
Wong, Derek F. and
|
||||||
|
Chao, Lidia S.",
|
||||||
|
booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
|
||||||
|
month = jul,
|
||||||
|
year = "2019",
|
||||||
|
address = "Florence, Italy",
|
||||||
|
publisher = "Association for Computational Linguistics",
|
||||||
|
url = "https://www.aclweb.org/anthology/P19-1176",
|
||||||
|
doi = "10.18653/v1/P19-1176",
|
||||||
|
pages = "1810--1822",
|
||||||
|
abstract = "Transformer is the state-of-the-art model in recent machine translation evaluations. Two strands of research are promising to improve models of this kind: the first uses wide networks (a.k.a. Transformer-Big) and has been the de facto standard for development of the Transformer system, and the other uses deeper language representation but faces the difficulty arising from learning deep networks. Here, we continue the line of research on the latter. We claim that a truly deep Transformer model can surpass the Transformer-Big counterpart by 1) proper use of layer normalization and 2) a novel way of passing the combination of previous layers to the next. On WMT{'}16 English-German and NIST OpenMT{'}12 Chinese-English tasks, our deep system (30/25-layer encoder) outperforms the shallow Transformer-Big/Base baseline (6-layer encoder) by 0.4-2.4 BLEU points. As another bonus, the deep model is 1.6X smaller in size and 3X faster in training than Transformer-Big.",
|
||||||
|
}
|
||||||
|
|
|
@ -23,19 +23,19 @@ wget -O data/speech.tgz http://lstm.seas.harvard.edu/latex/speech.tgz; tar zxf d
|
||||||
1) Preprocess the data.
|
1) Preprocess the data.
|
||||||
|
|
||||||
```
|
```
|
||||||
python preprocess.py -data_type audio -src_dir data/speech/an4_dataset -train_src data/speech/src-train.txt -train_tgt data/speech/tgt-train.txt -valid_src data/speech/src-val.txt -valid_tgt data/speech/tgt-val.txt -shard_size 300 -save_data data/speech/demo
|
onmt_preprocess -data_type audio -src_dir data/speech/an4_dataset -train_src data/speech/src-train.txt -train_tgt data/speech/tgt-train.txt -valid_src data/speech/src-val.txt -valid_tgt data/speech/tgt-val.txt -shard_size 300 -save_data data/speech/demo
|
||||||
```
|
```
|
||||||
|
|
||||||
2) Train the model.
|
2) Train the model.
|
||||||
|
|
||||||
```
|
```
|
||||||
python train.py -model_type audio -enc_rnn_size 512 -dec_rnn_size 512 -audio_enc_pooling 1,1,2,2 -dropout 0 -enc_layers 4 -dec_layers 1 -rnn_type LSTM -data data/speech/demo -save_model demo-model -global_attention mlp -gpu_ranks 0 -batch_size 8 -optim adam -max_grad_norm 100 -learning_rate 0.0003 -learning_rate_decay 0.8 -train_steps 100000
|
onmt_train -model_type audio -enc_rnn_size 512 -dec_rnn_size 512 -audio_enc_pooling 1,1,2,2 -dropout 0 -enc_layers 4 -dec_layers 1 -rnn_type LSTM -data data/speech/demo -save_model demo-model -global_attention mlp -gpu_ranks 0 -batch_size 8 -optim adam -max_grad_norm 100 -learning_rate 0.0003 -learning_rate_decay 0.8 -train_steps 100000
|
||||||
```
|
```
|
||||||
|
|
||||||
3) Translate the speech.
|
3) Translate the speech.
|
||||||
|
|
||||||
```
|
```
|
||||||
python translate.py -data_type audio -model demo-model_acc_x_ppl_x_e13.pt -src_dir data/speech/an4_dataset -src data/speech/src-val.txt -output pred.txt -gpu 0 -verbose
|
onmt_translate -data_type audio -model demo-model_acc_x_ppl_x_e13.pt -src_dir data/speech/an4_dataset -src data/speech/src-val.txt -output pred.txt -gpu 0 -verbose
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -158,19 +158,19 @@ Preprocess the data with
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
python preprocess.py -data_type vec -train_src $YT2T/yt2t_train_files.txt -src_dir $YT2T/r152/ -train_tgt $YT2T/yt2t_train_cap.txt -valid_src $YT2T/yt2t_val_files.txt -valid_tgt $YT2T/yt2t_val_cap.txt -save_data data/yt2t --shard_size 1000
|
onmt_preprocess -data_type vec -train_src $YT2T/yt2t_train_files.txt -src_dir $YT2T/r152/ -train_tgt $YT2T/yt2t_train_cap.txt -valid_src $YT2T/yt2t_val_files.txt -valid_tgt $YT2T/yt2t_val_cap.txt -save_data data/yt2t --shard_size 1000
|
||||||
|
|
||||||
Train with
|
Train with
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
python train.py -data data/yt2t -save_model yt2t-model -world_size 2 -gpu_ranks 0 1 -model_type vec -batch_size 64 -train_steps 10000 -valid_steps 500 -save_checkpoint_steps 500 -encoder_type brnn -optim adam -learning_rate .0001 -feat_vec_size 2048
|
onmt_train -data data/yt2t -save_model yt2t-model -world_size 2 -gpu_ranks 0 1 -model_type vec -batch_size 64 -train_steps 10000 -valid_steps 500 -save_checkpoint_steps 500 -encoder_type brnn -optim adam -learning_rate .0001 -feat_vec_size 2048
|
||||||
|
|
||||||
Translate with
|
Translate with
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
python translate.py -model yt2t-model_step_7200.pt -src $YT2T/yt2t_test_files.txt -output pred.txt -verbose -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 10
|
onmt_translate -model yt2t-model_step_7200.pt -src $YT2T/yt2t_test_files.txt -output pred.txt -verbose -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 10
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
|
@ -385,13 +385,13 @@ old data. Do this, i.e. ``rm data/yt2t.*.pt.``)
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
python preprocess.py -data_type vec -train_src $YT2T/yt2t_train_files.txt -src_dir $YT2T/r152/ -train_tgt $YT2T/yt2t_train_cap.txt -valid_src $YT2T/yt2t_val_files.txt -valid_tgt $YT2T/yt2t_val_cap.txt -save_data data/yt2t --shard_size 1000 --src_seq_length 50 --tgt_seq_length 20
|
onmt_preprocess -data_type vec -train_src $YT2T/yt2t_train_files.txt -src_dir $YT2T/r152/ -train_tgt $YT2T/yt2t_train_cap.txt -valid_src $YT2T/yt2t_val_files.txt -valid_tgt $YT2T/yt2t_val_cap.txt -save_data data/yt2t --shard_size 1000 --src_seq_length 50 --tgt_seq_length 20
|
||||||
|
|
||||||
Delete the old checkpoints and train a transformer model on this data.
|
Delete the old checkpoints and train a transformer model on this data.
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
rm -r yt2t-model_step_*.pt; python train.py -data data/yt2t -save_model yt2t-model -world_size 2 -gpu_ranks 0 1 -model_type vec -batch_size 64 -train_steps 8000 -valid_steps 400 -save_checkpoint_steps 400 -optim adam -learning_rate .0001 -feat_vec_size 2048 -layers 4 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -dropout 0.3 -param_init 0 -param_init_glorot -report_every 400 --share_decoder_embedding --seed 7000
|
rm -r yt2t-model_step_*.pt; onmt_train -data data/yt2t -save_model yt2t-model -world_size 2 -gpu_ranks 0 1 -model_type vec -batch_size 64 -train_steps 8000 -valid_steps 400 -save_checkpoint_steps 400 -optim adam -learning_rate .0001 -feat_vec_size 2048 -layers 4 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -dropout 0.3 -param_init 0 -param_init_glorot -report_every 400 --share_decoder_embedding --seed 7000
|
||||||
|
|
||||||
Note we use the hyperparameters described in the paper.
|
Note we use the hyperparameters described in the paper.
|
||||||
We estimate the length of 20 epochs with ``-train_steps``. Note that this depends on
|
We estimate the length of 20 epochs with ``-train_steps``. Note that this depends on
|
||||||
|
@ -502,7 +502,7 @@ although this too is not mentioned. You can reproduce our early-stops with these
|
||||||
for file in $( ls -1v yt2t-model_step*.pt )
|
for file in $( ls -1v yt2t-model_step*.pt )
|
||||||
do
|
do
|
||||||
echo $file
|
echo $file
|
||||||
python translate.py -model $file -src $YT2T/yt2t_val_folded_files.txt -output pred.txt -verbose -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 16 -max_length 20 >/dev/null 2>/dev/null
|
onmt_translate -model $file -src $YT2T/yt2t_val_folded_files.txt -output pred.txt -verbose -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 16 -max_length 20 >/dev/null 2>/dev/null
|
||||||
echo -e "$file\n" >> results.txt
|
echo -e "$file\n" >> results.txt
|
||||||
python coco.py -s val >> results.txt
|
python coco.py -s val >> results.txt
|
||||||
echo -e "\n\n" >> results.txt
|
echo -e "\n\n" >> results.txt
|
||||||
|
@ -519,7 +519,7 @@ although this too is not mentioned. You can reproduce our early-stops with these
|
||||||
metric=$(echo $line | awk '{print $1}')
|
metric=$(echo $line | awk '{print $1}')
|
||||||
step=$(echo $line | awk '{print $NF}')
|
step=$(echo $line | awk '{print $NF}')
|
||||||
echo $metric early stopped @ $step | tee -a test_results.txt
|
echo $metric early stopped @ $step | tee -a test_results.txt
|
||||||
python translate.py -model "yt2t-model_step_${step}.pt" -src $YT2T/yt2t_test_files.txt -output pred.txt -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 16 -max_length 20 >/dev/null 2>/dev/null
|
onmt_translate -model "yt2t-model_step_${step}.pt" -src $YT2T/yt2t_test_files.txt -output pred.txt -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 16 -max_length 20 >/dev/null 2>/dev/null
|
||||||
python coco.py -s 'test' >> test_results.txt
|
python coco.py -s 'test' >> test_results.txt
|
||||||
echo -e "\n\n" >> test_results.txt
|
echo -e "\n\n" >> test_results.txt
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -17,4 +17,4 @@ sys.modules["onmt.Optim"] = onmt.utils.optimizers
|
||||||
__all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models,
|
__all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models,
|
||||||
onmt.utils, onmt.modules, "Trainer"]
|
onmt.utils, onmt.modules, "Trainer"]
|
||||||
|
|
||||||
__version__ = "0.9.1"
|
__version__ = "1.1.1"
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def average_models(model_files, fp32=False):
|
||||||
|
vocab = None
|
||||||
|
opt = None
|
||||||
|
avg_model = None
|
||||||
|
avg_generator = None
|
||||||
|
|
||||||
|
for i, model_file in enumerate(model_files):
|
||||||
|
m = torch.load(model_file, map_location='cpu')
|
||||||
|
model_weights = m['model']
|
||||||
|
generator_weights = m['generator']
|
||||||
|
|
||||||
|
if fp32:
|
||||||
|
for k, v in model_weights.items():
|
||||||
|
model_weights[k] = v.float()
|
||||||
|
for k, v in generator_weights.items():
|
||||||
|
generator_weights[k] = v.float()
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
vocab, opt = m['vocab'], m['opt']
|
||||||
|
avg_model = model_weights
|
||||||
|
avg_generator = generator_weights
|
||||||
|
else:
|
||||||
|
for (k, v) in avg_model.items():
|
||||||
|
avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
|
||||||
|
|
||||||
|
for (k, v) in avg_generator.items():
|
||||||
|
avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
|
||||||
|
|
||||||
|
final = {"vocab": vocab, "opt": opt, "optim": None,
|
||||||
|
"generator": avg_generator, "model": avg_model}
|
||||||
|
return final
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="")
|
||||||
|
parser.add_argument("-models", "-m", nargs="+", required=True,
|
||||||
|
help="List of models")
|
||||||
|
parser.add_argument("-output", "-o", required=True,
|
||||||
|
help="Output file")
|
||||||
|
parser.add_argument("-fp32", "-f", action="store_true",
|
||||||
|
help="Cast params to float32")
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
final = average_models(opt.models, opt.fp32)
|
||||||
|
torch.save(final, opt.output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
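A usage sketch for the checkpoint-averaging script above, which folds each checkpoint's weights into a running mean. The `tools/average_models.py` path is an assumption about where this file lives in the repository; the `-m`, `-o` and `-f` flags come from the argument parser defined above.

```bash
# Average three checkpoints into one model file; -f casts the params to float32.
# The tools/average_models.py location is assumed, adjust to this repository.
python tools/average_models.py \
    -m demo-model_step_8000.pt demo-model_step_9000.pt demo-model_step_10000.pt \
    -o demo-model_avg.pt -f
```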
@ -0,0 +1,315 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Pre-process Data / features files and build vocabulary
|
||||||
|
"""
|
||||||
|
import codecs
|
||||||
|
import glob
|
||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
|
from onmt.utils.logging import init_logger, logger
|
||||||
|
from onmt.utils.misc import split_corpus
|
||||||
|
import onmt.inputters as inputters
|
||||||
|
import onmt.opts as opts
|
||||||
|
from onmt.utils.parse import ArgumentParser
|
||||||
|
from onmt.inputters.inputter import _build_fields_vocab,\
|
||||||
|
_load_vocab, \
|
||||||
|
old_style_vocab, \
|
||||||
|
load_old_vocab
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
|
||||||
|
def check_existing_pt_files(opt, corpus_type, ids, existing_fields):
|
||||||
|
""" Check if there are existing .pt files to avoid overwriting them """
|
||||||
|
existing_shards = []
|
||||||
|
for maybe_id in ids:
|
||||||
|
if maybe_id:
|
||||||
|
shard_base = corpus_type + "_" + maybe_id
|
||||||
|
else:
|
||||||
|
shard_base = corpus_type
|
||||||
|
pattern = opt.save_data + '.{}.*.pt'.format(shard_base)
|
||||||
|
if glob.glob(pattern):
|
||||||
|
if opt.overwrite:
|
||||||
|
maybe_overwrite = ("will be overwritten because "
|
||||||
|
"`-overwrite` option is set.")
|
||||||
|
else:
|
||||||
|
maybe_overwrite = ("won't be overwritten, pass the "
|
||||||
|
"`-overwrite` option if you want to.")
|
||||||
|
logger.warning("Shards for corpus {} already exist, {}"
|
||||||
|
.format(shard_base, maybe_overwrite))
|
||||||
|
existing_shards += [maybe_id]
|
||||||
|
return existing_shards
|
||||||
|
|
||||||
|
|
||||||
|
def process_one_shard(corpus_params, params):
|
||||||
|
corpus_type, fields, src_reader, tgt_reader, align_reader, opt,\
|
||||||
|
existing_fields, src_vocab, tgt_vocab = corpus_params
|
||||||
|
i, (src_shard, tgt_shard, align_shard, maybe_id, filter_pred) = params
|
||||||
|
# create one counter per shard
|
||||||
|
sub_sub_counter = defaultdict(Counter)
|
||||||
|
assert len(src_shard) == len(tgt_shard)
|
||||||
|
logger.info("Building shard %d." % i)
|
||||||
|
|
||||||
|
src_data = {"reader": src_reader, "data": src_shard, "dir": opt.src_dir}
|
||||||
|
tgt_data = {"reader": tgt_reader, "data": tgt_shard, "dir": None}
|
||||||
|
align_data = {"reader": align_reader, "data": align_shard, "dir": None}
|
||||||
|
_readers, _data, _dir = inputters.Dataset.config(
|
||||||
|
[('src', src_data), ('tgt', tgt_data), ('align', align_data)])
|
||||||
|
|
||||||
|
dataset = inputters.Dataset(
|
||||||
|
fields, readers=_readers, data=_data, dirs=_dir,
|
||||||
|
sort_key=inputters.str2sortkey[opt.data_type],
|
||||||
|
filter_pred=filter_pred,
|
||||||
|
corpus_id=maybe_id
|
||||||
|
)
|
||||||
|
if corpus_type == "train" and existing_fields is None:
|
||||||
|
for ex in dataset.examples:
|
||||||
|
sub_sub_counter['corpus_id'].update(
|
||||||
|
["train" if maybe_id is None else maybe_id])
|
||||||
|
for name, field in fields.items():
|
||||||
|
if ((opt.data_type == "audio") and (name == "src")):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
f_iter = iter(field)
|
||||||
|
except TypeError:
|
||||||
|
f_iter = [(name, field)]
|
||||||
|
all_data = [getattr(ex, name, None)]
|
||||||
|
else:
|
||||||
|
all_data = getattr(ex, name)
|
||||||
|
for (sub_n, sub_f), fd in zip(
|
||||||
|
f_iter, all_data):
|
||||||
|
has_vocab = (sub_n == 'src' and
|
||||||
|
src_vocab is not None) or \
|
||||||
|
(sub_n == 'tgt' and
|
||||||
|
tgt_vocab is not None)
|
||||||
|
if (hasattr(sub_f, 'sequential')
|
||||||
|
and sub_f.sequential and not has_vocab):
|
||||||
|
val = fd
|
||||||
|
sub_sub_counter[sub_n].update(val)
|
||||||
|
if maybe_id:
|
||||||
|
shard_base = corpus_type + "_" + maybe_id
|
||||||
|
else:
|
||||||
|
shard_base = corpus_type
|
||||||
|
data_path = "{:s}.{:s}.{:d}.pt".\
|
||||||
|
format(opt.save_data, shard_base, i)
|
||||||
|
|
||||||
|
logger.info(" * saving %sth %s data shard to %s."
|
||||||
|
% (i, shard_base, data_path))
|
||||||
|
|
||||||
|
dataset.save(data_path)
|
||||||
|
|
||||||
|
del dataset.examples
|
||||||
|
gc.collect()
|
||||||
|
del dataset
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return sub_sub_counter
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_load_vocab(corpus_type, counters, opt):
|
||||||
|
src_vocab = None
|
||||||
|
tgt_vocab = None
|
||||||
|
existing_fields = None
|
||||||
|
if corpus_type == "train":
|
||||||
|
if opt.src_vocab != "":
|
||||||
|
try:
|
||||||
|
logger.info("Using existing vocabulary...")
|
||||||
|
existing_fields = torch.load(opt.src_vocab)
|
||||||
|
except torch.serialization.pickle.UnpicklingError:
|
||||||
|
logger.info("Building vocab from text file...")
|
||||||
|
src_vocab, src_vocab_size = _load_vocab(
|
||||||
|
opt.src_vocab, "src", counters,
|
||||||
|
opt.src_words_min_frequency)
|
||||||
|
if opt.tgt_vocab != "":
|
||||||
|
tgt_vocab, tgt_vocab_size = _load_vocab(
|
||||||
|
opt.tgt_vocab, "tgt", counters,
|
||||||
|
opt.tgt_words_min_frequency)
|
||||||
|
return src_vocab, tgt_vocab, existing_fields
|
||||||
|
|
||||||
|
|
||||||
|
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader,
|
||||||
|
align_reader, opt):
|
||||||
|
assert corpus_type in ['train', 'valid']
|
||||||
|
|
||||||
|
if corpus_type == 'train':
|
||||||
|
counters = defaultdict(Counter)
|
||||||
|
srcs = opt.train_src
|
||||||
|
tgts = opt.train_tgt
|
||||||
|
ids = opt.train_ids
|
||||||
|
aligns = opt.train_align
|
||||||
|
elif corpus_type == 'valid':
|
||||||
|
counters = None
|
||||||
|
srcs = [opt.valid_src]
|
||||||
|
tgts = [opt.valid_tgt]
|
||||||
|
ids = [None]
|
||||||
|
aligns = [opt.valid_align]
|
||||||
|
|
||||||
|
src_vocab, tgt_vocab, existing_fields = maybe_load_vocab(
|
||||||
|
corpus_type, counters, opt)
|
||||||
|
|
||||||
|
existing_shards = check_existing_pt_files(
|
||||||
|
opt, corpus_type, ids, existing_fields)
|
||||||
|
|
||||||
|
# every corpus has shards, no new one
|
||||||
|
if existing_shards == ids and not opt.overwrite:
|
||||||
|
return
|
||||||
|
|
||||||
|
def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
|
||||||
|
existing_fields, corpus_type, opt):
|
||||||
|
"""
|
||||||
|
Builds a single iterator yielding every shard of every corpus.
|
||||||
|
"""
|
||||||
|
for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
|
||||||
|
if maybe_id in existing_shards:
|
||||||
|
if opt.overwrite:
|
||||||
|
logger.warning("Overwrite shards for corpus {}"
|
||||||
|
.format(maybe_id))
|
||||||
|
else:
|
||||||
|
if corpus_type == "train":
|
||||||
|
assert existing_fields is not None,\
|
||||||
|
("A 'vocab.pt' file should be passed to "
|
||||||
|
"`-src_vocab` when adding a corpus to "
|
||||||
|
"a set of already existing shards.")
|
||||||
|
logger.warning("Ignore corpus {} because "
|
||||||
|
"shards already exist"
|
||||||
|
.format(maybe_id))
|
||||||
|
continue
|
||||||
|
if ((corpus_type == "train" or opt.filter_valid)
|
||||||
|
and tgt is not None):
|
||||||
|
filter_pred = partial(
|
||||||
|
inputters.filter_example,
|
||||||
|
use_src_len=opt.data_type == "text",
|
||||||
|
max_src_len=opt.src_seq_length,
|
||||||
|
max_tgt_len=opt.tgt_seq_length)
|
||||||
|
else:
|
||||||
|
filter_pred = None
|
||||||
|
src_shards = split_corpus(src, opt.shard_size)
|
||||||
|
tgt_shards = split_corpus(tgt, opt.shard_size)
|
||||||
|
align_shards = split_corpus(maybe_align, opt.shard_size)
|
||||||
|
for i, (ss, ts, a_s) in enumerate(
|
||||||
|
zip(src_shards, tgt_shards, align_shards)):
|
||||||
|
yield (i, (ss, ts, a_s, maybe_id, filter_pred))
|
||||||
|
|
||||||
|
shard_iter = shard_iterator(srcs, tgts, ids, aligns, existing_shards,
|
||||||
|
existing_fields, corpus_type, opt)
|
||||||
|
|
||||||
|
with Pool(opt.num_threads) as p:
|
||||||
|
dataset_params = (corpus_type, fields, src_reader, tgt_reader,
|
||||||
|
align_reader, opt, existing_fields,
|
||||||
|
src_vocab, tgt_vocab)
|
||||||
|
func = partial(process_one_shard, dataset_params)
|
||||||
|
for sub_counter in p.imap(func, shard_iter):
|
||||||
|
if sub_counter is not None:
|
||||||
|
for key, value in sub_counter.items():
|
||||||
|
counters[key].update(value)
|
||||||
|
|
||||||
|
if corpus_type == "train":
|
||||||
|
vocab_path = opt.save_data + '.vocab.pt'
|
||||||
|
new_fields = _build_fields_vocab(
|
||||||
|
fields, counters, opt.data_type,
|
||||||
|
opt.share_vocab, opt.vocab_size_multiple,
|
||||||
|
opt.src_vocab_size, opt.src_words_min_frequency,
|
||||||
|
opt.tgt_vocab_size, opt.tgt_words_min_frequency,
|
||||||
|
subword_prefix=opt.subword_prefix,
|
||||||
|
subword_prefix_is_joiner=opt.subword_prefix_is_joiner)
|
||||||
|
if existing_fields is None:
|
||||||
|
fields = new_fields
|
||||||
|
else:
|
||||||
|
fields = existing_fields
|
||||||
|
|
||||||
|
if old_style_vocab(fields):
|
||||||
|
fields = load_old_vocab(
|
||||||
|
fields, opt.data_type, dynamic_dict=opt.dynamic_dict)
|
||||||
|
|
||||||
|
# patch corpus_id
|
||||||
|
if fields.get("corpus_id", False):
|
||||||
|
fields["corpus_id"].vocab = new_fields["corpus_id"].vocab_cls(
|
||||||
|
counters["corpus_id"])
|
||||||
|
|
||||||
|
torch.save(fields, vocab_path)
|
||||||
|
|
||||||
|
|
||||||
|
def build_save_vocab(train_dataset, fields, opt):
|
||||||
|
fields = inputters.build_vocab(
|
||||||
|
train_dataset, fields, opt.data_type, opt.share_vocab,
|
||||||
|
opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency,
|
||||||
|
opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency,
|
||||||
|
vocab_size_multiple=opt.vocab_size_multiple
|
||||||
|
)
|
||||||
|
vocab_path = opt.save_data + '.vocab.pt'
|
||||||
|
torch.save(fields, vocab_path)
|
||||||
|
|
||||||
|
|
||||||
|
def count_features(path):
|
||||||
|
"""
|
||||||
|
path: location of a corpus file with whitespace-delimited tokens and
|
||||||
|
│-delimited features within the token
|
||||||
|
returns: the number of features in the dataset
|
||||||
|
"""
|
||||||
|
with codecs.open(path, "r", "utf-8") as f:
|
||||||
|
first_tok = f.readline().split(None, 1)[0]
|
||||||
|
return len(first_tok.split(u"│")) - 1
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(opt):
|
||||||
|
ArgumentParser.validate_preprocess_args(opt)
|
||||||
|
torch.manual_seed(opt.seed)
|
||||||
|
|
||||||
|
init_logger(opt.log_file)
|
||||||
|
|
||||||
|
logger.info("Extracting features...")
|
||||||
|
|
||||||
|
src_nfeats = 0
|
||||||
|
tgt_nfeats = 0
|
||||||
|
for src, tgt in zip(opt.train_src, opt.train_tgt):
|
||||||
|
src_nfeats += count_features(src) if opt.data_type == 'text' \
|
||||||
|
else 0
|
||||||
|
tgt_nfeats += count_features(tgt) # tgt always text so far
|
||||||
|
logger.info(" * number of source features: %d." % src_nfeats)
|
||||||
|
logger.info(" * number of target features: %d." % tgt_nfeats)
|
||||||
|
|
||||||
|
logger.info("Building `Fields` object...")
|
||||||
|
fields = inputters.get_fields(
|
||||||
|
opt.data_type,
|
||||||
|
src_nfeats,
|
||||||
|
tgt_nfeats,
|
||||||
|
dynamic_dict=opt.dynamic_dict,
|
||||||
|
with_align=opt.train_align[0] is not None,
|
||||||
|
src_truncate=opt.src_seq_length_trunc,
|
||||||
|
tgt_truncate=opt.tgt_seq_length_trunc)
|
||||||
|
|
||||||
|
src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
|
||||||
|
tgt_reader = inputters.str2reader["text"].from_opt(opt)
|
||||||
|
align_reader = inputters.str2reader["text"].from_opt(opt)
|
||||||
|
|
||||||
|
logger.info("Building & saving training data...")
|
||||||
|
build_save_dataset(
|
||||||
|
'train', fields, src_reader, tgt_reader, align_reader, opt)
|
||||||
|
|
||||||
|
if opt.valid_src and opt.valid_tgt:
|
||||||
|
logger.info("Building & saving validation data...")
|
||||||
|
build_save_dataset(
|
||||||
|
'valid', fields, src_reader, tgt_reader, align_reader, opt)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parser():
|
||||||
|
parser = ArgumentParser(description='preprocess.py')
|
||||||
|
|
||||||
|
opts.config_opts(parser)
|
||||||
|
opts.preprocess_opts(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = _get_parser()
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
preprocess(opt)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
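The shard handling above supports several named training corpora. A hedged sketch of such a run follows; the flag spellings mirror the `opt.train_src`, `opt.train_tgt` and `opt.train_ids` attributes read by `build_save_dataset`, and the corpus ids and paths are purely illustrative.

```bash
# Hypothetical sketch: preprocess two named training corpora into separate shards
# (train_europarl.*.pt and train_news.*.pt) sharing one vocabulary.
onmt_preprocess -train_src data/europarl.src data/news.src \
                -train_tgt data/europarl.tgt data/news.tgt \
                -train_ids europarl news \
                -valid_src data/valid.src -valid_tgt data/valid.tgt \
                -save_data data/multi -shard_size 100000
```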
@ -0,0 +1,59 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_ctranslate2_model_spec(opt):
|
||||||
|
"""Creates a CTranslate2 model specification from the model options."""
|
||||||
|
with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
|
||||||
|
is_ct2_compatible = (
|
||||||
|
opt.encoder_type == "transformer"
|
||||||
|
and opt.decoder_type == "transformer"
|
||||||
|
and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot"
|
||||||
|
and ((opt.position_encoding and not with_relative_position)
|
||||||
|
or (with_relative_position and not opt.position_encoding)))
|
||||||
|
if not is_ct2_compatible:
|
||||||
|
return None
|
||||||
|
import ctranslate2
|
||||||
|
num_heads = getattr(opt, "heads", 8)
|
||||||
|
return ctranslate2.specs.TransformerSpec(
|
||||||
|
(opt.enc_layers, opt.dec_layers),
|
||||||
|
num_heads,
|
||||||
|
with_relative_position=with_relative_position)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Release an OpenNMT-py model for inference")
|
||||||
|
parser.add_argument("--model", "-m",
|
||||||
|
help="The model path", required=True)
|
||||||
|
parser.add_argument("--output", "-o",
|
||||||
|
help="The output path", required=True)
|
||||||
|
parser.add_argument("--format",
|
||||||
|
choices=["pytorch", "ctranslate2"],
|
||||||
|
default="pytorch",
|
||||||
|
help="The format of the released model")
|
||||||
|
parser.add_argument("--quantization", "-q",
|
||||||
|
choices=["int8", "int16"],
|
||||||
|
default=None,
|
||||||
|
help="Quantization type for CT2 model.")
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
model = torch.load(opt.model)
|
||||||
|
if opt.format == "pytorch":
|
||||||
|
model["optim"] = None
|
||||||
|
torch.save(model, opt.output)
|
||||||
|
elif opt.format == "ctranslate2":
|
||||||
|
model_spec = get_ctranslate2_model_spec(model["opt"])
|
||||||
|
if model_spec is None:
|
||||||
|
raise ValueError("This model is not supported by CTranslate2. Go "
|
||||||
|
"to https://github.com/OpenNMT/CTranslate2 for "
|
||||||
|
"more information on supported models.")
|
||||||
|
import ctranslate2
|
||||||
|
converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
|
||||||
|
converter.convert(opt.output, model_spec, force=True,
|
||||||
|
quantization=opt.quantization)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
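A usage sketch for the release script above. The `tools/release_model.py` path is an assumption; the `--model`, `--output`, `--format` and `--quantization` flags come from its argument parser.

```bash
# Strip the optimizer state for a smaller PyTorch release checkpoint.
python tools/release_model.py --model demo-model_step_10000.pt \
    --output demo-model_release.pt --format pytorch

# Convert a compatible Transformer checkpoint to CTranslate2 with int8 weights.
python tools/release_model.py --model demo-model_step_10000.pt \
    --output demo-model_ct2 --format ctranslate2 --quantization int8
```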
@ -0,0 +1,157 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import configargparse
|
||||||
|
|
||||||
|
from flask import Flask, jsonify, request
|
||||||
|
from waitress import serve
|
||||||
|
from onmt.translate import TranslationServer, ServerModelError
|
||||||
|
import logging
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
STATUS_OK = "ok"
|
||||||
|
STATUS_ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
def start(config_file,
|
||||||
|
url_root="./translator",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=5000,
|
||||||
|
debug=False):
|
||||||
|
def prefix_route(route_function, prefix='', mask='{0}{1}'):
|
||||||
|
def newroute(route, *args, **kwargs):
|
||||||
|
return route_function(mask.format(prefix, route), *args, **kwargs)
|
||||||
|
return newroute
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
logger = logging.getLogger("main")
|
||||||
|
log_format = logging.Formatter(
|
||||||
|
"[%(asctime)s %(levelname)s] %(message)s")
|
||||||
|
file_handler = RotatingFileHandler(
|
||||||
|
"debug_requests.log",
|
||||||
|
maxBytes=1000000, backupCount=10)
|
||||||
|
file_handler.setFormatter(log_format)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.route = prefix_route(app.route, url_root)
|
||||||
|
translation_server = TranslationServer()
|
||||||
|
translation_server.start(config_file)
|
||||||
|
|
||||||
|
@app.route('/models', methods=['GET'])
|
||||||
|
def get_models():
|
||||||
|
out = translation_server.list_models()
|
||||||
|
return jsonify(out)
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
|
||||||
|
def health():
|
||||||
|
out = {}
|
||||||
|
out['status'] = STATUS_OK
|
||||||
|
return jsonify(out)
|
||||||
|
|
||||||
|
@app.route('/clone_model/<int:model_id>', methods=['POST'])
|
||||||
|
def clone_model(model_id):
|
||||||
|
out = {}
|
||||||
|
data = request.get_json(force=True)
|
||||||
|
timeout = -1
|
||||||
|
if 'timeout' in data:
|
||||||
|
timeout = data['timeout']
|
||||||
|
del data['timeout']
|
||||||
|
|
||||||
|
opt = data.get('opt', None)
|
||||||
|
try:
|
||||||
|
model_id, load_time = translation_server.clone_model(
|
||||||
|
model_id, opt, timeout)
|
||||||
|
except ServerModelError as e:
|
||||||
|
out['status'] = STATUS_ERROR
|
||||||
|
out['error'] = str(e)
|
||||||
|
else:
|
||||||
|
out['status'] = STATUS_OK
|
||||||
|
out['model_id'] = model_id
|
||||||
|
out['load_time'] = load_time
|
||||||
|
|
||||||
|
return jsonify(out)
|
||||||
|
|
||||||
|
@app.route('/unload_model/<int:model_id>', methods=['GET'])
|
||||||
|
def unload_model(model_id):
|
||||||
|
out = {"model_id": model_id}
|
||||||
|
|
||||||
|
try:
|
||||||
|
translation_server.unload_model(model_id)
|
||||||
|
out['status'] = STATUS_OK
|
||||||
|
except Exception as e:
|
||||||
|
out['status'] = STATUS_ERROR
|
||||||
|
out['error'] = str(e)
|
||||||
|
|
||||||
|
return jsonify(out)
|
||||||
|
|
||||||
|
@app.route('/translate', methods=['POST'])
|
||||||
|
def translate():
|
||||||
|
inputs = request.get_json(force=True)
|
||||||
|
if debug:
|
||||||
|
logger.info(inputs)
|
||||||
|
out = {}
|
||||||
|
try:
|
||||||
|
trans, scores, n_best, _, aligns = translation_server.run(inputs)
|
||||||
|
assert len(trans) == len(inputs) * n_best
|
||||||
|
assert len(scores) == len(inputs) * n_best
|
||||||
|
assert len(aligns) == len(inputs) * n_best
|
||||||
|
|
||||||
|
out = [[] for _ in range(n_best)]
|
||||||
|
for i in range(len(trans)):
|
||||||
|
response = {"src": inputs[i // n_best]['src'], "tgt": trans[i],
|
||||||
|
"n_best": n_best, "pred_score": scores[i]}
|
||||||
|
if aligns[i][0] is not None:
|
||||||
|
response["align"] = aligns[i]
|
||||||
|
out[i % n_best].append(response)
|
||||||
|
except ServerModelError as e:
|
||||||
|
model_id = inputs[0].get("id")
|
||||||
|
if debug:
|
||||||
|
logger.warning("Unload model #{} "
|
||||||
|
"because of an error".format(model_id))
|
||||||
|
translation_server.models[model_id].unload()
|
||||||
|
out['error'] = str(e)
|
||||||
|
out['status'] = STATUS_ERROR
|
||||||
|
if debug:
|
||||||
|
logger.info(out)
|
||||||
|
return jsonify(out)
|
||||||
|
|
||||||
|
@app.route('/to_cpu/<int:model_id>', methods=['GET'])
|
||||||
|
def to_cpu(model_id):
|
||||||
|
out = {'model_id': model_id}
|
||||||
|
translation_server.models[model_id].to_cpu()
|
||||||
|
|
||||||
|
out['status'] = STATUS_OK
|
||||||
|
return jsonify(out)
|
||||||
|
|
||||||
|
@app.route('/to_gpu/<int:model_id>', methods=['GET'])
|
||||||
|
def to_gpu(model_id):
|
||||||
|
out = {'model_id': model_id}
|
||||||
|
translation_server.models[model_id].to_gpu()
|
||||||
|
|
||||||
|
out['status'] = STATUS_OK
|
||||||
|
return jsonify(out)
|
||||||
|
|
||||||
|
serve(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parser():
|
||||||
|
parser = configargparse.ArgumentParser(
|
||||||
|
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
||||||
|
description="OpenNMT-py REST Server")
|
||||||
|
parser.add_argument("--ip", type=str, default="0.0.0.0")
|
||||||
|
parser.add_argument("--port", type=int, default="5000")
|
||||||
|
parser.add_argument("--url_root", type=str, default="/translator")
|
||||||
|
parser.add_argument("--debug", "-d", action="store_true")
|
||||||
|
parser.add_argument("--config", "-c", type=str,
|
||||||
|
default="./available_models/conf.json")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = _get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
start(args.config, url_root=args.url_root, host=args.ip, port=args.port,
|
||||||
|
debug=args.debug)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
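A hedged sketch of starting and querying this REST server. The `python server.py` invocation and the model id are assumptions; the port, the `/translator` URL root, the default config path and the request shape all come from the code above.

```bash
# Start the server with the defaults from _get_parser().
python server.py --ip 0.0.0.0 --port 5000 --url_root /translator \
    --config ./available_models/conf.json

# List available models, then translate. The body is a JSON list of objects with
# "id" (a model id from the config, 100 here is illustrative) and "src" keys,
# matching what the translate() route reads above.
curl -s http://localhost:5000/translator/models
curl -s -X POST http://localhost:5000/translator/translate \
    -H 'Content-Type: application/json' \
    -d '[{"id": 100, "src": "Hello world !"}]'
```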
@ -0,0 +1,213 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
"""Train models."""
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import onmt.opts as opts
|
||||||
|
import onmt.utils.distributed
|
||||||
|
|
||||||
|
from onmt.utils.misc import set_random_seed
|
||||||
|
from onmt.utils.logging import init_logger, logger
|
||||||
|
from onmt.train_single import main as single_main
|
||||||
|
from onmt.utils.parse import ArgumentParser
|
||||||
|
from onmt.inputters.inputter import build_dataset_iter, patch_fields, \
|
||||||
|
load_old_vocab, old_style_vocab, build_dataset_iter_multiple
|
||||||
|
|
||||||
|
from itertools import cycle
|
||||||
|
|
||||||
|
|
||||||
|
def train(opt):
|
||||||
|
ArgumentParser.validate_train_opts(opt)
|
||||||
|
ArgumentParser.update_model_opts(opt)
|
||||||
|
ArgumentParser.validate_model_opts(opt)
|
||||||
|
|
||||||
|
set_random_seed(opt.seed, False)
|
||||||
|
|
||||||
|
# Load checkpoint if we resume from a previous training.
|
||||||
|
if opt.train_from:
|
||||||
|
logger.info('Loading checkpoint from %s' % opt.train_from)
|
||||||
|
checkpoint = torch.load(opt.train_from,
|
||||||
|
map_location=lambda storage, loc: storage)
|
||||||
|
logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
|
||||||
|
vocab = checkpoint['vocab']
|
||||||
|
else:
|
||||||
|
vocab = torch.load(opt.data + '.vocab.pt')
|
||||||
|
|
||||||
|
# check for code where vocab is saved instead of fields
|
||||||
|
# (in the future this will be done in a smarter way)
|
||||||
|
if old_style_vocab(vocab):
|
||||||
|
fields = load_old_vocab(
|
||||||
|
vocab, opt.model_type, dynamic_dict=opt.copy_attn)
|
||||||
|
else:
|
||||||
|
fields = vocab
|
||||||
|
|
||||||
|
# patch for fields that may be missing in old data/model
|
||||||
|
patch_fields(opt, fields)
|
||||||
|
|
||||||
|
if len(opt.data_ids) > 1:
|
||||||
|
train_shards = []
|
||||||
|
for train_id in opt.data_ids:
|
||||||
|
shard_base = "train_" + train_id
|
||||||
|
train_shards.append(shard_base)
|
||||||
|
train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
|
||||||
|
else:
|
||||||
|
if opt.data_ids[0] is not None:
|
||||||
|
shard_base = "train_" + opt.data_ids[0]
|
||||||
|
else:
|
||||||
|
shard_base = "train"
|
||||||
|
train_iter = build_dataset_iter(shard_base, fields, opt)
|
||||||
|
|
||||||
|
nb_gpu = len(opt.gpu_ranks)
|
||||||
|
|
||||||
|
if opt.world_size > 1:
|
||||||
|
queues = []
|
||||||
|
mp = torch.multiprocessing.get_context('spawn')
|
||||||
|
semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
|
||||||
|
# Create a thread to listen for errors in the child processes.
|
||||||
|
error_queue = mp.SimpleQueue()
|
||||||
|
error_handler = ErrorHandler(error_queue)
|
||||||
|
# Train with multiprocessing.
|
||||||
|
procs = []
|
||||||
|
for device_id in range(nb_gpu):
|
||||||
|
q = mp.Queue(opt.queue_size)
|
||||||
|
queues += [q]
|
||||||
|
procs.append(mp.Process(target=run, args=(
|
||||||
|
opt, device_id, error_queue, q, semaphore), daemon=True))
|
||||||
|
procs[device_id].start()
|
||||||
|
logger.info(" Starting process pid: %d " % procs[device_id].pid)
|
||||||
|
error_handler.add_child(procs[device_id].pid)
|
||||||
|
producer = mp.Process(target=batch_producer,
|
||||||
|
args=(train_iter, queues, semaphore, opt,),
|
||||||
|
daemon=True)
|
||||||
|
producer.start()
|
||||||
|
error_handler.add_child(producer.pid)
|
||||||
|
|
||||||
|
for p in procs:
|
||||||
|
p.join()
|
||||||
|
producer.terminate()
|
||||||
|
|
||||||
|
elif nb_gpu == 1: # case 1 GPU only
|
||||||
|
single_main(opt, 0)
|
||||||
|
else: # case only CPU
|
||||||
|
single_main(opt, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_producer(generator_to_serve, queues, semaphore, opt):
|
||||||
|
init_logger(opt.log_file)
|
||||||
|
set_random_seed(opt.seed, False)
|
||||||
|
# generator_to_serve = iter(generator_to_serve)
|
||||||
|
|
||||||
|
def pred(x):
|
||||||
|
"""
|
||||||
|
Filters batches that belong only
|
||||||
|
to gpu_ranks of current node
|
||||||
|
"""
|
||||||
|
for rank in opt.gpu_ranks:
|
||||||
|
if x[0] % opt.world_size == rank:
|
||||||
|
return True
|
||||||
|
|
||||||
|
generator_to_serve = filter(
|
||||||
|
pred, enumerate(generator_to_serve))
|
||||||
|
|
||||||
|
def next_batch(device_id):
|
||||||
|
new_batch = next(generator_to_serve)
|
||||||
|
semaphore.acquire()
|
||||||
|
return new_batch[1]
|
||||||
|
|
||||||
|
b = next_batch(0)
|
||||||
|
|
||||||
|
for device_id, q in cycle(enumerate(queues)):
|
||||||
|
b.dataset = None
|
||||||
|
if isinstance(b.src, tuple):
|
||||||
|
b.src = tuple([_.to(torch.device(device_id))
|
||||||
|
for _ in b.src])
|
||||||
|
else:
|
||||||
|
b.src = b.src.to(torch.device(device_id))
|
||||||
|
b.tgt = b.tgt.to(torch.device(device_id))
|
||||||
|
b.indices = b.indices.to(torch.device(device_id))
|
||||||
|
b.alignment = b.alignment.to(torch.device(device_id)) \
|
||||||
|
if hasattr(b, 'alignment') else None
|
||||||
|
b.src_map = b.src_map.to(torch.device(device_id)) \
|
||||||
|
if hasattr(b, 'src_map') else None
|
||||||
|
b.align = b.align.to(torch.device(device_id)) \
|
||||||
|
if hasattr(b, 'align') else None
|
||||||
|
b.corpus_id = b.corpus_id.to(torch.device(device_id)) \
|
||||||
|
if hasattr(b, 'corpus_id') else None
|
||||||
|
|
||||||
|
# hack to dodge unpicklable `dict_keys`
|
||||||
|
b.fields = list(b.fields)
|
||||||
|
q.put(b)
|
||||||
|
b = next_batch(device_id)
|
||||||
|
|
||||||
|
|
||||||
|
def run(opt, device_id, error_queue, batch_queue, semaphore):
|
||||||
|
""" run process """
|
||||||
|
try:
|
||||||
|
gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
|
||||||
|
if gpu_rank != opt.gpu_ranks[device_id]:
|
||||||
|
raise AssertionError("An error occurred in \
|
||||||
|
Distributed initialization")
|
||||||
|
single_main(opt, device_id, batch_queue, semaphore)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass # killed by parent, do nothing
|
||||||
|
except Exception:
|
||||||
|
# propagate exception to parent process, keeping original traceback
|
||||||
|
import traceback
|
||||||
|
error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorHandler(object):
|
||||||
|
"""A class that listens for exceptions in children processes and propagates
|
||||||
|
the tracebacks to the parent process."""
|
||||||
|
|
||||||
|
def __init__(self, error_queue):
|
||||||
|
""" init error handler """
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
self.error_queue = error_queue
|
||||||
|
self.children_pids = []
|
||||||
|
self.error_thread = threading.Thread(
|
||||||
|
target=self.error_listener, daemon=True)
|
||||||
|
self.error_thread.start()
|
||||||
|
signal.signal(signal.SIGUSR1, self.signal_handler)
|
||||||
|
|
||||||
|
def add_child(self, pid):
|
||||||
|
""" error handler """
|
||||||
|
self.children_pids.append(pid)
|
||||||
|
|
||||||
|
def error_listener(self):
|
||||||
|
""" error listener """
|
||||||
|
(rank, original_trace) = self.error_queue.get()
|
||||||
|
self.error_queue.put((rank, original_trace))
|
||||||
|
os.kill(os.getpid(), signal.SIGUSR1)
|
||||||
|
|
||||||
|
def signal_handler(self, signalnum, stackframe):
|
||||||
|
""" signal handler """
|
||||||
|
for pid in self.children_pids:
|
||||||
|
os.kill(pid, signal.SIGINT) # kill children processes
|
||||||
|
(rank, original_trace) = self.error_queue.get()
|
||||||
|
msg = """\n\n-- Tracebacks above this line can probably
|
||||||
|
be ignored --\n\n"""
|
||||||
|
msg += original_trace
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parser():
|
||||||
|
parser = ArgumentParser(description='train.py')
|
||||||
|
|
||||||
|
opts.config_opts(parser)
|
||||||
|
opts.model_opts(parser)
|
||||||
|
opts.train_opts(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = _get_parser()
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
train(opt)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
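A sketch of launching the multi-process path above on two GPUs. The `-world_size` and `-gpu_ranks` flags appear earlier in this document; `-queue_size` mirrors the `opt.queue_size` attribute used by `batch_producer` and its spelling is an assumption.

```bash
# One producer process feeds batches to one training process per listed GPU rank.
onmt_train -data data/demo -save_model demo-model \
           -world_size 2 -gpu_ranks 0 1 -queue_size 40
```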
@ -0,0 +1,52 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from onmt.utils.logging import init_logger
|
||||||
|
from onmt.utils.misc import split_corpus
|
||||||
|
from onmt.translate.translator import build_translator
|
||||||
|
|
||||||
|
import onmt.opts as opts
|
||||||
|
from onmt.utils.parse import ArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def translate(opt):
|
||||||
|
ArgumentParser.validate_translate_opts(opt)
|
||||||
|
logger = init_logger(opt.log_file)
|
||||||
|
|
||||||
|
translator = build_translator(opt, report_score=True)
|
||||||
|
src_shards = split_corpus(opt.src, opt.shard_size)
|
||||||
|
tgt_shards = split_corpus(opt.tgt, opt.shard_size)
|
||||||
|
shard_pairs = zip(src_shards, tgt_shards)
|
||||||
|
|
||||||
|
for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
|
||||||
|
logger.info("Translating shard %d." % i)
|
||||||
|
translator.translate(
|
||||||
|
src=src_shard,
|
||||||
|
tgt=tgt_shard,
|
||||||
|
src_dir=opt.src_dir,
|
||||||
|
batch_size=opt.batch_size,
|
||||||
|
batch_type=opt.batch_type,
|
||||||
|
attn_debug=opt.attn_debug,
|
||||||
|
align_debug=opt.align_debug
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parser():
|
||||||
|
parser = ArgumentParser(description='translate.py')
|
||||||
|
|
||||||
|
opts.config_opts(parser)
|
||||||
|
opts.translate_opts(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = _get_parser()
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
translate(opt)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
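A sketch of a sharded translation run matching the loop above, which splits `-src` (and optionally `-tgt`) into chunks of `-shard_size` lines and translates one shard at a time; paths and values are illustrative.

```bash
# Translate a large test set in shards of 10000 lines each.
onmt_translate -model demo-model_step_10000.pt -src data/src-test.txt \
               -output pred.txt -shard_size 10000 -gpu 0 -verbose
```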
@ -190,7 +190,8 @@ class RNNDecoderBase(DecoderBase):
|
||||||
self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
|
self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
|
||||||
self.state["input_feed"] = self.state["input_feed"].detach()
|
self.state["input_feed"] = self.state["input_feed"].detach()
|
||||||
|
|
||||||
def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
|
def forward(self, tgt, memory_bank, memory_lengths=None, step=None,
|
||||||
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
tgt (LongTensor): sequences of padded tokens
|
tgt (LongTensor): sequences of padded tokens
|
||||||
|
|
|
@ -51,7 +51,8 @@ class EnsembleDecoder(DecoderBase):
|
||||||
super(EnsembleDecoder, self).__init__(attentional)
|
super(EnsembleDecoder, self).__init__(attentional)
|
||||||
self.model_decoders = model_decoders
|
self.model_decoders = model_decoders
|
||||||
|
|
||||||
def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
|
def forward(self, tgt, memory_bank, memory_lengths=None, step=None,
|
||||||
|
**kwargs):
|
||||||
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
|
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
|
||||||
# Memory_lengths is a single tensor shared between all models.
|
# Memory_lengths is a single tensor shared between all models.
|
||||||
# This assumption will not hold if Translator is modified
|
# This assumption will not hold if Translator is modified
|
||||||
|
@ -60,7 +61,7 @@ class EnsembleDecoder(DecoderBase):
|
||||||
dec_outs, attns = zip(*[
|
dec_outs, attns = zip(*[
|
||||||
model_decoder(
|
model_decoder(
|
||||||
tgt, memory_bank[i],
|
tgt, memory_bank[i],
|
||||||
memory_lengths=memory_lengths, step=step)
|
memory_lengths=memory_lengths, step=step, **kwargs)
|
||||||
for i, model_decoder in enumerate(self.model_decoders)])
|
for i, model_decoder in enumerate(self.model_decoders)])
|
||||||
mean_attns = self.combine_attns(attns)
|
mean_attns = self.combine_attns(attns)
|
||||||
return EnsembleDecoderOutput(dec_outs), mean_attns
|
return EnsembleDecoderOutput(dec_outs), mean_attns
|
||||||
|
|
|
@ -12,25 +12,51 @@ from onmt.utils.misc import sequence_mask
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoderLayer(nn.Module):
|
class TransformerDecoderLayer(nn.Module):
|
||||||
"""
|
"""Transformer Decoder layer block in Pre-Norm style.
|
||||||
|
Pre-Norm style is an improvement over the original paper's Post-Norm style,
|
||||||
|
providing better convergence speed and performance. This is also the actual
|
||||||
|
implementation in tensor2tensor and is also available in fairseq.
|
||||||
|
See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
|
||||||
|
|
||||||
|
.. mermaid::
|
||||||
|
|
||||||
|
graph LR
|
||||||
|
%% "*SubLayer" can be self-attn, src-attn or feed forward block
|
||||||
|
A(input) --> B[Norm]
|
||||||
|
B --> C["*SubLayer"]
|
||||||
|
C --> D[Drop]
|
||||||
|
D --> E((+))
|
||||||
|
A --> E
|
||||||
|
E --> F(out)
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
d_model (int): the dimension of keys/values/queries in
|
d_model (int): the dimension of keys/values/queries in
|
||||||
:class:`MultiHeadedAttention`, also the input size of
|
:class:`MultiHeadedAttention`, also the input size of
|
||||||
the first-layer of the :class:`PositionwiseFeedForward`.
|
the first-layer of the :class:`PositionwiseFeedForward`.
|
||||||
heads (int): the number of heads for MultiHeadedAttention.
|
heads (int): the number of heads for MultiHeadedAttention.
|
||||||
d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
|
d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
|
||||||
dropout (float): dropout probability.
|
dropout (float): dropout in residual, self-attn(dot) and feed-forward
|
||||||
|
attention_dropout (float): dropout in context_attn (and self-attn(avg))
|
||||||
self_attn_type (string): type of self-attention scaled-dot, average
|
self_attn_type (string): type of self-attention scaled-dot, average
|
||||||
|
max_relative_positions (int):
|
||||||
|
Max distance between inputs in relative positions representations
|
||||||
|
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
|
||||||
|
full_context_alignment (bool):
|
||||||
|
whether to enable an extra full-context decoder forward pass for alignment
|
||||||
|
alignment_heads (int):
|
||||||
|
Number of cross-attention heads to use for alignment guidance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
|
def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
|
||||||
self_attn_type="scaled-dot", max_relative_positions=0,
|
self_attn_type="scaled-dot", max_relative_positions=0,
|
||||||
aan_useffn=False):
|
aan_useffn=False, full_context_alignment=False,
|
||||||
|
alignment_heads=0):
|
||||||
super(TransformerDecoderLayer, self).__init__()
|
super(TransformerDecoderLayer, self).__init__()
|
||||||
|
|
||||||
if self_attn_type == "scaled-dot":
|
if self_attn_type == "scaled-dot":
|
||||||
self.self_attn = MultiHeadedAttention(
|
self.self_attn = MultiHeadedAttention(
|
||||||
heads, d_model, dropout=dropout,
|
heads, d_model, dropout=attention_dropout,
|
||||||
max_relative_positions=max_relative_positions)
|
max_relative_positions=max_relative_positions)
|
||||||
elif self_attn_type == "average":
|
elif self_attn_type == "average":
|
||||||
self.self_attn = AverageAttention(d_model,
|
self.self_attn = AverageAttention(d_model,
|
||||||
|
@ -43,54 +69,106 @@ class TransformerDecoderLayer(nn.Module):
|
||||||
self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
|
self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
|
||||||
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
|
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
|
||||||
self.drop = nn.Dropout(dropout)
|
self.drop = nn.Dropout(dropout)
|
||||||
|
self.full_context_alignment = full_context_alignment
|
||||||
|
self.alignment_heads = alignment_heads
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
""" Extend `_forward` for (possibly) multiple decoder pass:
|
||||||
|
Always a default (future masked) decoder forward pass,
|
||||||
|
Possibly a second future aware decoder pass for joint learn
|
||||||
|
full context alignement, :cite:`garg2019jointly`.
|
||||||
|
|
||||||
def forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
|
|
||||||
layer_cache=None, step=None):
|
|
||||||
"""
|
|
||||||
Args:
|
Args:
|
||||||
inputs (FloatTensor): ``(batch_size, 1, model_dim)``
|
* All arguments of _forward.
|
||||||
|
with_align (bool): whether return alignment attention.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(FloatTensor, FloatTensor, FloatTensor or None):
|
||||||
|
|
||||||
|
* output ``(batch_size, T, model_dim)``
|
||||||
|
* top_attn ``(batch_size, T, src_len)``
|
||||||
|
* attn_align ``(batch_size, T, src_len)`` or None
|
||||||
|
"""
|
||||||
|
with_align = kwargs.pop('with_align', False)
|
||||||
|
output, attns = self._forward(*args, **kwargs)
|
||||||
|
top_attn = attns[:, 0, :, :].contiguous()
|
||||||
|
attn_align = None
|
||||||
|
if with_align:
|
||||||
|
if self.full_context_alignment:
|
||||||
|
# return _, (B, Q_len, K_len)
|
||||||
|
_, attns = self._forward(*args, **kwargs, future=True)
|
||||||
|
|
||||||
|
if self.alignment_heads > 0:
|
||||||
|
attns = attns[:, :self.alignment_heads, :, :].contiguous()
|
||||||
|
# layer average attention across heads, get ``(B, Q, K)``
|
||||||
|
# Case 1: no full_context, no align heads -> layer avg baseline
|
||||||
|
# Case 2: no full_context, 1 align heads -> guided align
|
||||||
|
# Case 3: full_context, 1 align heads -> full cte guided align
|
||||||
|
attn_align = attns.mean(dim=1)
|
||||||
|
return output, top_attn, attn_align
|
||||||
|
|
||||||
|
def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
|
||||||
|
layer_cache=None, step=None, future=False):
|
||||||
|
""" A naive forward pass for transformer decoder.
|
||||||
|
|
||||||
|
# T: could be 1 in the case of stepwise decoding or tgt_len
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs (FloatTensor): ``(batch_size, T, model_dim)``
|
||||||
memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
|
memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
|
||||||
src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
|
src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
|
||||||
tgt_pad_mask (LongTensor): ``(batch_size, 1, 1)``
|
tgt_pad_mask (LongTensor): ``(batch_size, 1, T)``
|
||||||
|
layer_cache (dict or None): cached layer info when stepwise decode
|
||||||
|
step (int or None): stepwise decoding counter
|
||||||
|
future (bool): If set True, do not apply future_mask.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(FloatTensor, FloatTensor):
|
(FloatTensor, FloatTensor):
|
||||||
|
|
||||||
* output ``(batch_size, 1, model_dim)``
|
* output ``(batch_size, T, model_dim)``
|
||||||
* attn ``(batch_size, 1, src_len)``
|
* attns ``(batch_size, head, T, src_len)``
|
||||||
|
|
||||||
"""
|
"""
|
||||||
dec_mask = None
|
dec_mask = None
|
||||||
|
|
||||||
if step is None:
|
if step is None:
|
||||||
tgt_len = tgt_pad_mask.size(-1)
|
tgt_len = tgt_pad_mask.size(-1)
|
||||||
|
if not future: # apply future_mask, result mask in (B, T, T)
|
||||||
future_mask = torch.ones(
|
future_mask = torch.ones(
|
||||||
[tgt_len, tgt_len],
|
[tgt_len, tgt_len],
|
||||||
device=tgt_pad_mask.device,
|
device=tgt_pad_mask.device,
|
||||||
dtype=torch.uint8)
|
dtype=torch.uint8)
|
||||||
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
|
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
|
||||||
|
# BoolTensor was introduced in pytorch 1.2
|
||||||
|
try:
|
||||||
|
future_mask = future_mask.bool()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
|
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
|
||||||
|
else: # only mask padding, result mask in (B, 1, T)
|
||||||
|
dec_mask = tgt_pad_mask
|
||||||
|
|
||||||
input_norm = self.layer_norm_1(inputs)
|
input_norm = self.layer_norm_1(inputs)
|
||||||
|
|
||||||
if isinstance(self.self_attn, MultiHeadedAttention):
|
if isinstance(self.self_attn, MultiHeadedAttention):
|
||||||
query, attn = self.self_attn(input_norm, input_norm, input_norm,
|
query, _ = self.self_attn(input_norm, input_norm, input_norm,
|
||||||
mask=dec_mask,
|
mask=dec_mask,
|
||||||
layer_cache=layer_cache,
|
layer_cache=layer_cache,
|
||||||
attn_type="self")
|
attn_type="self")
|
||||||
elif isinstance(self.self_attn, AverageAttention):
|
elif isinstance(self.self_attn, AverageAttention):
|
||||||
query, attn = self.self_attn(input_norm, mask=dec_mask,
|
query, _ = self.self_attn(input_norm, mask=dec_mask,
|
||||||
layer_cache=layer_cache, step=step)
|
layer_cache=layer_cache, step=step)
|
||||||
|
|
||||||
query = self.drop(query) + inputs
|
query = self.drop(query) + inputs
|
||||||
|
|
||||||
query_norm = self.layer_norm_2(query)
|
query_norm = self.layer_norm_2(query)
|
||||||
mid, attn = self.context_attn(memory_bank, memory_bank, query_norm,
|
mid, attns = self.context_attn(memory_bank, memory_bank, query_norm,
|
||||||
mask=src_pad_mask,
|
mask=src_pad_mask,
|
||||||
layer_cache=layer_cache,
|
layer_cache=layer_cache,
|
||||||
attn_type="context")
|
attn_type="context")
|
||||||
output = self.feed_forward(self.drop(mid) + query)
|
output = self.feed_forward(self.drop(mid) + query)
|
||||||
|
|
||||||
return output, attn
|
return output, attns
|
||||||
|
|
||||||
def update_dropout(self, dropout, attention_dropout):
|
def update_dropout(self, dropout, attention_dropout):
|
||||||
self.self_attn.update_dropout(attention_dropout)
|
self.self_attn.update_dropout(attention_dropout)
|
||||||
|
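The mermaid diagram above encodes the Pre-Norm ordering: LayerNorm is applied to the sublayer input, the sublayer output is passed through dropout and then added back to the raw input. A minimal, standalone sketch of one such residual block; the sublayer here is just a linear map to make the wiring concrete, not OpenNMT's actual attention module:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # input -> Norm -> *SubLayer -> Drop -> (+ input), as in the diagram above
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
        self.sublayer = nn.Linear(d_model, d_model)  # stand-in for self/context attention or FFN
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        return x + self.drop(self.sublayer(self.norm(x)))

x = torch.randn(2, 5, 16)
print(PreNormBlock(16)(x).shape)  # torch.Size([2, 5, 16])
```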
@@ -124,14 +202,25 @@ class TransformerDecoder(DecoderBase):
d_ff (int): size of the inner FF layer
copy_attn (bool): if using a separate copy attention
self_attn_type (str): type of self-attention scaled-dot, average
- dropout (float): dropout parameters
+ dropout (float): dropout in residual, self-attn(dot) and feed-forward
+ attention_dropout (float): dropout in context_attn (and self-attn(avg))
embeddings (onmt.modules.Embeddings):
embeddings to use, should have positional encodings
+ max_relative_positions (int):
+ Max distance between inputs in relative positions representations
+ aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+ full_context_alignment (bool):
+ whether enable an extra full context decoder forward for alignment
+ alignment_layer (int): N° Layer to supervise with for alignment guiding
+ alignment_heads (int):
+ N. of cross attention heads to use for alignment guiding
"""

def __init__(self, num_layers, d_model, heads, d_ff,
copy_attn, self_attn_type, dropout, attention_dropout,
- embeddings, max_relative_positions, aan_useffn):
+ embeddings, max_relative_positions, aan_useffn,
+ full_context_alignment, alignment_layer,
+ alignment_heads):
super(TransformerDecoder, self).__init__()

self.embeddings = embeddings

@@ -143,7 +232,9 @@ class TransformerDecoder(DecoderBase):
[TransformerDecoderLayer(d_model, heads, d_ff, dropout,
attention_dropout, self_attn_type=self_attn_type,
max_relative_positions=max_relative_positions,
- aan_useffn=aan_useffn)
+ aan_useffn=aan_useffn,
+ full_context_alignment=full_context_alignment,
+ alignment_heads=alignment_heads)
for i in range(num_layers)])

# previously, there was a GlobalAttention module here for copy

@@ -152,6 +243,8 @@ class TransformerDecoder(DecoderBase):
self._copy = copy_attn
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

+ self.alignment_layer = alignment_layer

@classmethod
def from_opt(cls, opt, embeddings):
"""Alternate constructor."""

@@ -167,7 +260,10 @@ class TransformerDecoder(DecoderBase):
is list else opt.dropout,
embeddings,
opt.max_relative_positions,
- opt.aan_useffn)
+ opt.aan_useffn,
+ opt.full_context_alignment,
+ opt.alignment_layer,
+ alignment_heads=opt.alignment_heads)

def init_state(self, src, memory_bank, enc_hidden):
"""Initialize decoder state."""

@@ -209,16 +305,22 @@ class TransformerDecoder(DecoderBase):
src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1) # [B, 1, T_tgt]

+ with_align = kwargs.pop('with_align', False)
+ attn_aligns = []

for i, layer in enumerate(self.transformer_layers):
layer_cache = self.state["cache"]["layer_{}".format(i)] \
if step is not None else None
- output, attn = layer(
+ output, attn, attn_align = layer(
output,
src_memory_bank,
src_pad_mask,
tgt_pad_mask,
layer_cache=layer_cache,
- step=step)
+ step=step,
+ with_align=with_align)
+ if attn_align is not None:
+ attn_aligns.append(attn_align)

output = self.layer_norm(output)
dec_outs = output.transpose(0, 1).contiguous()

@@ -227,6 +329,9 @@ class TransformerDecoder(DecoderBase):
attns = {"std": attn}
if self._copy:
attns["copy"] = attn
+ if with_align:
+ attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
+ # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg

# TODO change the way attns is returned dict => list or tuple (onnx)
return dec_outs, attns
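Downstream, the `attns["align"]` tensor collected above (batch, target length, source length, from the layer selected by `alignment_layer`) can be reduced to hard word alignments by taking, per target position, the source position with the highest averaged attention. A rough sketch under that assumption; the conversion itself is illustrative and not part of the diff:

```python
import torch

def hard_align(align_probs):
    # align_probs: (batch, tgt_len, src_len), e.g. attns["align"]
    src_idx = align_probs.argmax(dim=-1)  # best source token per target token
    pairs = []
    for b in range(align_probs.size(0)):
        pairs.append([(int(s), t) for t, s in enumerate(src_idx[b].tolist())])
    return pairs  # per sentence: list of (src, tgt) index pairs

align = torch.softmax(torch.randn(2, 4, 6), dim=-1)
print(hard_align(align)[0])
```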
@@ -1,6 +1,7 @@
"""Module defining encoders."""
from onmt.encoders.encoder import EncoderBase
from onmt.encoders.transformer import TransformerEncoder
+ from onmt.encoders.ggnn_encoder import GGNNEncoder
from onmt.encoders.rnn_encoder import RNNEncoder
from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder

@@ -8,9 +9,9 @@ from onmt.encoders.audio_encoder import AudioEncoder
from onmt.encoders.image_encoder import ImageEncoder


- str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder,
- "transformer": TransformerEncoder, "img": ImageEncoder,
- "audio": AudioEncoder, "mean": MeanEncoder}
+ str2enc = {"ggnn": GGNNEncoder, "rnn": RNNEncoder, "brnn": RNNEncoder,
+ "cnn": CNNEncoder, "transformer": TransformerEncoder,
+ "img": ImageEncoder, "audio": AudioEncoder, "mean": MeanEncoder}

__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder",
"MeanEncoder", "str2enc"]
@@ -0,0 +1,279 @@ (new file: GGNN encoder)
"""Define GGNN-based encoders."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from onmt.encoders.encoder import EncoderBase


class GGNNAttrProxy(object):
"""
Translates index lookups into attribute lookups.
To implement some trick which able to use list of nn.Module in a nn.Module
see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2
"""
def __init__(self, module, prefix):
self.module = module
self.prefix = prefix

def __getitem__(self, i):
return getattr(self.module, self.prefix + str(i))


class GGNNPropogator(nn.Module):
"""
Gated Propogator for GGNN
Using LSTM gating mechanism
"""
def __init__(self, state_dim, n_node, n_edge_types):
super(GGNNPropogator, self).__init__()

self.n_node = n_node
self.n_edge_types = n_edge_types

self.reset_gate = nn.Sequential(
nn.Linear(state_dim*3, state_dim),
nn.Sigmoid()
)
self.update_gate = nn.Sequential(
nn.Linear(state_dim*3, state_dim),
nn.Sigmoid()
)
self.tansform = nn.Sequential(
nn.Linear(state_dim*3, state_dim),
nn.LeakyReLU()
)

def forward(self, state_in, state_out, state_cur, edges, nodes):
edges_in = edges[:, :, :nodes*self.n_edge_types]
edges_out = edges[:, :, nodes*self.n_edge_types:]

a_in = torch.bmm(edges_in, state_in)
a_out = torch.bmm(edges_out, state_out)
a = torch.cat((a_in, a_out, state_cur), 2)

r = self.reset_gate(a)
z = self.update_gate(a)
joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
h_hat = self.tansform(joined_input)

output = (1 - z) * state_cur + z * h_hat

return output


class GGNNEncoder(EncoderBase):
""" A gated graph neural network configured as an encoder.
Based on github.com/JamesChuanggg/ggnn.pytorch.git,
which is based on the paper "Gated Graph Sequence Neural Networks"
by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.

Args:
rnn_type (str):
style of recurrent unit to use, one of [LSTM]
state_dim (int) : Number of state dimensions in nodes
n_edge_types (int) : Number of edge types
bidir_edges (bool): True if reverse edges should be autocreated
n_node (int) : Max nodes in graph
bridge_extra_node (bool): True indicates only 1st extra node
(after token listing) should be used for decoder init.
n_steps (int): Steps to advance graph encoder for stabilization
src_vocab (int): Path to source vocabulary
"""

def __init__(self, rnn_type, state_dim, bidir_edges,
n_edge_types, n_node, bridge_extra_node, n_steps, src_vocab):
super(GGNNEncoder, self).__init__()

self.state_dim = state_dim
self.n_edge_types = n_edge_types
self.n_node = n_node
self.n_steps = n_steps
self.bidir_edges = bidir_edges
self.bridge_extra_node = bridge_extra_node

for i in range(self.n_edge_types):
# incoming and outgoing edge embedding
in_fc = nn.Linear(self.state_dim, self.state_dim)
out_fc = nn.Linear(self.state_dim, self.state_dim)
self.add_module("in_{}".format(i), in_fc)
self.add_module("out_{}".format(i), out_fc)

self.in_fcs = GGNNAttrProxy(self, "in_")
self.out_fcs = GGNNAttrProxy(self, "out_")

# Find vocab data for tree builting
f = open(src_vocab, "r")
idx = 0
self.COMMA = -1
self.DELIMITER = -1
self.idx2num = []
for ln in f:
ln = ln.strip('\n')
if ln == ",":
self.COMMA = idx
if ln == "<EOT>":
self.DELIMITER = idx
if ln.isdigit():
self.idx2num.append(int(ln))
else:
self.idx2num.append(-1)
idx += 1

# Propogation Model
self.propogator = GGNNPropogator(self.state_dim, self.n_node,
self.n_edge_types)

self._initialization()

# Initialize the bridge layer
self._initialize_bridge(rnn_type, self.state_dim, 1)

@classmethod
def from_opt(cls, opt, embeddings):
"""Alternate constructor."""
return cls(
opt.rnn_type,
opt.state_dim,
opt.bidir_edges,
opt.n_edge_types,
opt.n_node,
opt.bridge_extra_node,
opt.n_steps,
opt.src_vocab)

def _initialization(self):
for m in self.modules():
if isinstance(m, nn.Linear):
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)

def forward(self, src, lengths=None):
"""See :func:`EncoderBase.forward()`"""
self._check_args(src, lengths)
nodes = self.n_node
batch_size = src.size()[1]
first_extra = np.zeros(batch_size, dtype=np.int32)
prop_state = np.zeros((batch_size, nodes, self.state_dim),
dtype=np.int32)
edges = np.zeros((batch_size, nodes, nodes*self.n_edge_types*2),
dtype=np.int32)
npsrc = src[:, :, 0].cpu().data.numpy().astype(np.int32)

# Initialize graph using formatted input sequence
for i in range(batch_size):
tokens_done = False
# Number of flagged nodes defines node count for this sample
# (Nodes can have no flags on them, but must be in 'flags' list).
flags = 0
flags_done = False
edge = 0
source_node = -1
for j in range(len(npsrc)):
token = npsrc[j][i]
if not tokens_done:
if token == self.DELIMITER:
tokens_done = True
first_extra[i] = j
else:
prop_state[i][j][token] = 1
elif token == self.DELIMITER:
flags += 1
flags_done = True
assert flags <= nodes
elif not flags_done:
# The total number of integers in the vocab should allow
# for all features and edges to be defined.
if token == self.COMMA:
flags = 0
else:
num = self.idx2num[token]
if num >= 0:
prop_state[i][flags][num+self.DELIMITER] = 1
flags += 1
elif token == self.COMMA:
edge += 1
assert source_node == -1, 'Error in graph edge input'
assert (edge <= 2*self.n_edge_types and
(not self.bidir_edges or edge < self.n_edge_types))
else:
num = self.idx2num[token]
if source_node < 0:
source_node = num
else:
edges[i][source_node][num+nodes*edge] = 1
if self.bidir_edges:
edges[i][num][nodes*(edge+self.n_edge_types)
+ source_node] = 1
source_node = -1

if torch.cuda.is_available():
prop_state = torch.from_numpy(prop_state).float().to("cuda:0")
edges = torch.from_numpy(edges).float().to("cuda:0")
else:
prop_state = torch.from_numpy(prop_state).float()
edges = torch.from_numpy(edges).float()

for i_step in range(self.n_steps):
in_states = []
out_states = []
for i in range(self.n_edge_types):
in_states.append(self.in_fcs[i](prop_state))
out_states.append(self.out_fcs[i](prop_state))
in_states = torch.stack(in_states).transpose(0, 1).contiguous()
in_states = in_states.view(-1, nodes*self.n_edge_types,
self.state_dim)
out_states = torch.stack(out_states).transpose(0, 1).contiguous()
out_states = out_states.view(-1, nodes*self.n_edge_types,
self.state_dim)

prop_state = self.propogator(in_states, out_states, prop_state,
edges, nodes)

prop_state = prop_state.transpose(0, 1)
if self.bridge_extra_node:
# Use first extra node as only source for decoder init
join_state = prop_state[first_extra, torch.arange(batch_size)]
else:
# Average all nodes to get bridge input
join_state = prop_state.mean(0)
join_state = torch.stack((join_state, join_state,
join_state, join_state))
join_state = (join_state, join_state)

encoder_final = self._bridge(join_state)

return encoder_final, prop_state, lengths

def _initialize_bridge(self, rnn_type,
hidden_size,
num_layers):

# LSTM has hidden and cell state, other only one
number_of_states = 2 if rnn_type == "LSTM" else 1
# Total number of states
self.total_hidden_dim = hidden_size * num_layers

# Build a linear layer for each
self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim,
self.total_hidden_dim,
bias=True)
for _ in range(number_of_states)])

def _bridge(self, hidden):
"""Forward hidden state through bridge."""
def bottle_hidden(linear, states):
"""
Transform from 3D to 2D, apply linear and return initial size
"""
size = states.size()
result = linear(states.view(-1, self.total_hidden_dim))
return F.leaky_relu(result).view(size)

if isinstance(hidden, tuple):  # LSTM
outs = tuple([bottle_hidden(layer, hidden[ix])
for ix, layer in enumerate(self.bridge)])
else:
outs = bottle_hidden(self.bridge[0], hidden)
return outs
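The propagator above is essentially a GRU-style gated update over node states: reset gate r, update gate z, candidate state h_hat, and h <- (1 - z) * h + z * h_hat. A self-contained sketch of one propagation step with random tensors; the dimensions are made up, and the real module also mixes in the per-edge-type messages via `torch.bmm` as in the file above:

```python
import torch
import torch.nn as nn

state_dim, batch, nodes = 16, 2, 5
reset_gate = nn.Sequential(nn.Linear(3 * state_dim, state_dim), nn.Sigmoid())
update_gate = nn.Sequential(nn.Linear(3 * state_dim, state_dim), nn.Sigmoid())
transform = nn.Sequential(nn.Linear(3 * state_dim, state_dim), nn.LeakyReLU())

a_in = torch.randn(batch, nodes, state_dim)   # aggregated incoming edge messages
a_out = torch.randn(batch, nodes, state_dim)  # aggregated outgoing edge messages
h = torch.randn(batch, nodes, state_dim)      # current node states

a = torch.cat((a_in, a_out, h), dim=2)
r = reset_gate(a)
z = update_gate(a)
h_hat = transform(torch.cat((a_in, a_out, r * h), dim=2))
h_next = (1 - z) * h + z * h_hat              # gated state update
print(h_next.shape)  # torch.Size([2, 5, 16])
```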
@@ -13,7 +13,6 @@ from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader
from onmt.inputters.vec_dataset import vec_sort_key, VecDataReader
from onmt.inputters.datareader_base import DataReaderBase


str2reader = {
"text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader,
"vec": VecDataReader}
@@ -108,7 +108,7 @@ class Dataset(TorchtextDataset):
"""

def __init__(self, fields, readers, data, dirs, sort_key,
- filter_pred=None):
+ filter_pred=None, corpus_id=None):
self.sort_key = sort_key
can_copy = 'src_map' in fields and 'alignment' in fields

@@ -119,6 +119,10 @@ class Dataset(TorchtextDataset):
self.src_vocabs = []
examples = []
for ex_dict in starmap(_join_dicts, zip(*read_iters)):
+ if corpus_id is not None:
+ ex_dict["corpus_id"] = corpus_id
+ else:
+ ex_dict["corpus_id"] = "train"
if can_copy:
src_field = fields['src']
tgt_field = fields['tgt']

@@ -152,3 +156,13 @@ class Dataset(TorchtextDataset):
if remove_fields:
self.fields = []
torch.save(self, path)

+ @staticmethod
+ def config(fields):
+ readers, data, dirs = [], [], []
+ for name, field in fields:
+ if field["data"] is not None:
+ readers.append(field["reader"])
+ data.append((name, field["data"]))
+ dirs.append(field["dir"])
+ return readers, data, dirs
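The `corpus_id` bookkeeping added here stamps every example with the corpus it came from, defaulting to "train", so later stages (for instance the source-noise functions below) can decide per corpus whether to transform a batch. A tiny illustration of the same idea outside torchtext; plain dictionaries stand in for the real `ex_dict` objects:

```python
def tag_examples(examples, corpus_id=None):
    # Mirrors the loop in Dataset.__init__: every example dict gets a
    # corpus_id, falling back to "train" when none is given.
    for ex_dict in examples:
        ex_dict["corpus_id"] = corpus_id if corpus_id is not None else "train"
    return examples

print(tag_examples([{"src": "a b c"}], corpus_id="news")[0]["corpus_id"])  # news
print(tag_examples([{"src": "d e"}])[0]["corpus_id"])                      # train
```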
@@ -9,7 +9,7 @@ from itertools import chain, cycle

import torch
import torchtext.data
- from torchtext.data import Field, RawField
+ from torchtext.data import Field, RawField, LabelField
from torchtext.vocab import Vocab
from torchtext.data.utils import RandomShuffler

@@ -58,6 +58,47 @@ def make_tgt(data, vocab):
return alignment


+ class AlignField(LabelField):
+ """
+ Parse ['<src>-<tgt>', ...] into ['<src>','<tgt>', ...]
+ """
+
+ def __init__(self, **kwargs):
+ kwargs['use_vocab'] = False
+ kwargs['preprocessing'] = parse_align_idx
+ super(AlignField, self).__init__(**kwargs)
+
+ def process(self, batch, device=None):
+ """ Turn a batch of align-idx to a sparse align idx Tensor"""
+ sparse_idx = []
+ for i, example in enumerate(batch):
+ for src, tgt in example:
+ # +1 for tgt side to keep coherent after "bos" padding,
+ # register ['N°_in_batch', 'tgt_id+1', 'src_id']
+ sparse_idx.append([i, tgt + 1, src])
+
+ align_idx = torch.tensor(sparse_idx, dtype=self.dtype, device=device)
+
+ return align_idx
+
+
+ def parse_align_idx(align_pharaoh):
+ """
+ Parse Pharaoh alignment into [[<src>, <tgt>], ...]
+ """
+ align_list = align_pharaoh.strip().split(' ')
+ flatten_align_idx = []
+ for align in align_list:
+ try:
+ src_idx, tgt_idx = align.split('-')
+ except ValueError:
+ logger.warning("{} in `{}`".format(align, align_pharaoh))
+ logger.warning("Bad alignement line exists. Please check file!")
+ raise
+ flatten_align_idx.append([int(src_idx), int(tgt_idx)])
+ return flatten_align_idx
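`parse_align_idx` expects Pharaoh-style alignments such as "0-0 1-2 2-1", and `AlignField.process` then shifts the target index by one to stay consistent with the prepended BOS token. A quick sanity check of that parsing logic, re-implemented standalone so it runs without OpenNMT installed:

```python
def parse_align_idx(align_pharaoh):
    # Same logic as the new helper in the hunk above.
    flatten_align_idx = []
    for align in align_pharaoh.strip().split(' '):
        src_idx, tgt_idx = align.split('-')
        flatten_align_idx.append([int(src_idx), int(tgt_idx)])
    return flatten_align_idx

example = parse_align_idx("0-0 1-2 2-1")
print(example)  # [[0, 0], [1, 2], [2, 1]]
# AlignField.process would then register rows [batch_index, tgt + 1, src]:
print([[0, t + 1, s] for s, t in example])  # [[0, 1, 0], [0, 3, 1], [0, 2, 2]]
```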
def get_fields(
src_data_type,
n_src_feats,

@@ -66,6 +107,7 @@ def get_fields(
bos='<s>',
eos='</s>',
dynamic_dict=False,
+ with_align=False,
src_truncate=None,
tgt_truncate=None
):

@@ -84,6 +126,7 @@ def get_fields(
for tgt.
dynamic_dict (bool): Whether or not to include source map and
alignment fields.
+ with_align (bool): Whether or not to include word align.
src_truncate: Cut off src sequences beyond this (passed to
``src_data_type``'s data reader - see there for more details).
tgt_truncate: Cut off tgt sequences beyond this (passed to

@@ -122,6 +165,9 @@ def get_fields(
indices = Field(use_vocab=False, dtype=torch.long, sequential=False)
fields["indices"] = indices

+ corpus_ids = Field(use_vocab=True, sequential=False)
+ fields["corpus_id"] = corpus_ids

if dynamic_dict:
src_map = Field(
use_vocab=False, dtype=torch.float,

@@ -136,9 +182,20 @@ def get_fields(
postprocessing=make_tgt, sequential=False)
fields["alignment"] = align

+ if with_align:
+ word_align = AlignField()
+ fields["align"] = word_align

return fields


+ def patch_fields(opt, fields):
+ dvocab = torch.load(opt.data + '.vocab.pt')
+ maybe_cid_field = dvocab.get('corpus_id', None)
+ if maybe_cid_field is not None:
+ fields.update({'corpus_id': maybe_cid_field})


def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
"""Update a legacy vocab/field format.

@@ -317,7 +374,9 @@ def _build_fv_from_multifield(multifield, counters, build_fv_args,
def _build_fields_vocab(fields, counters, data_type, share_vocab,
vocab_size_multiple,
src_vocab_size, src_words_min_frequency,
- tgt_vocab_size, tgt_words_min_frequency):
+ tgt_vocab_size, tgt_words_min_frequency,
+ subword_prefix="▁",
+ subword_prefix_is_joiner=False):
build_fv_args = defaultdict(dict)
build_fv_args["src"] = dict(
max_size=src_vocab_size, min_freq=src_words_min_frequency)

@@ -329,6 +388,11 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab,
counters,
build_fv_args,
size_multiple=vocab_size_multiple if not share_vocab else 1)

+ if fields.get("corpus_id", False):
+ fields["corpus_id"].vocab = fields["corpus_id"].vocab_cls(
+ counters["corpus_id"])

if data_type == 'text':
src_multifield = fields["src"]
_build_fv_from_multifield(

@@ -336,6 +400,7 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab,
counters,
build_fv_args,
size_multiple=vocab_size_multiple if not share_vocab else 1)

if share_vocab:
# `tgt_vocab_size` is ignored when sharing vocabularies
logger.info(" * merging src and tgt vocab...")

@@ -347,9 +412,38 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab,
vocab_size_multiple=vocab_size_multiple)
logger.info(" * merged vocab size: %d." % len(src_field.vocab))

+ build_noise_field(
+ src_multifield.base_field,
+ subword_prefix=subword_prefix,
+ is_joiner=subword_prefix_is_joiner)
return fields


+ def build_noise_field(src_field, subword=True,
+ subword_prefix="▁", is_joiner=False,
+ sentence_breaks=[".", "?", "!"]):
+ """In place add noise related fields i.e.:
+ - word_start
+ - end_of_sentence
+ """
+ if subword:
+ def is_word_start(x): return (x.startswith(subword_prefix) ^ is_joiner)
+ sentence_breaks = [subword_prefix + t for t in sentence_breaks]
+ else:
+ def is_word_start(x): return True
+
+ vocab_size = len(src_field.vocab)
+ word_start_mask = torch.zeros([vocab_size]).bool()
+ end_of_sentence_mask = torch.zeros([vocab_size]).bool()
+ for i, t in enumerate(src_field.vocab.itos):
+ if is_word_start(t):
+ word_start_mask[i] = True
+ if t in sentence_breaks:
+ end_of_sentence_mask[i] = True
+ src_field.word_start_mask = word_start_mask
+ src_field.end_of_sentence_mask = end_of_sentence_mask


def build_vocab(train_dataset_files, fields, data_type, share_vocab,
src_vocab_path, src_vocab_size, src_words_min_frequency,
tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency,
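The `build_noise_field` helper added above precomputes two boolean masks over the source vocabulary: which tokens start a word (using the SentencePiece prefix "▁" unless the prefix marks a joiner) and which tokens end a sentence. A hedged, standalone sketch over a toy vocabulary list; the real code attaches the masks to the torchtext field object rather than printing them:

```python
import torch

itos = ["<unk>", "<pad>", "▁the", "▁cat", "s", "▁.", "▁!"]  # toy vocab
subword_prefix, is_joiner = "▁", False
sentence_breaks = [subword_prefix + t for t in [".", "?", "!"]]

word_start_mask = torch.zeros(len(itos)).bool()
end_of_sentence_mask = torch.zeros(len(itos)).bool()
for i, t in enumerate(itos):
    if t.startswith(subword_prefix) ^ is_joiner:  # word-initial subword
        word_start_mask[i] = True
    if t in sentence_breaks:                      # sentence-final punctuation
        end_of_sentence_mask[i] = True

print(word_start_mask.tolist())       # [False, False, True, True, False, True, True]
print(end_of_sentence_mask.tolist())  # [False, False, False, False, False, True, True]
```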
@@ -776,11 +870,14 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
to iterate over. We implement simple ordered iterator strategy here,
but more sophisticated strategy like curriculum learning is ok too.
"""
+ dataset_glob = opt.data + '.' + corpus_type + '.[0-9]*.pt'
dataset_paths = list(sorted(
- glob.glob(opt.data + '.' + corpus_type + '.[0-9]*.pt')))
+ glob.glob(dataset_glob),
+ key=lambda p: int(p.split(".")[-2])))

if not dataset_paths:
if is_train:
- raise ValueError('Training data %s not found' % opt.data)
+ raise ValueError('Training data %s not found' % dataset_glob)
else:
return None
if multi:
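The new `key=lambda p: int(p.split(".")[-2])` matters because a plain lexicographic sort would place shard 10 before shard 2. A small demonstration with hypothetical shard file names:

```python
paths = ["demo.train.0.pt", "demo.train.10.pt", "demo.train.2.pt"]

print(sorted(paths))
# ['demo.train.0.pt', 'demo.train.10.pt', 'demo.train.2.pt']  (lexicographic, wrong order)

print(sorted(paths, key=lambda p: int(p.split(".")[-2])))
# ['demo.train.0.pt', 'demo.train.2.pt', 'demo.train.10.pt']  (numeric shard order)
```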
@@ -190,6 +190,8 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
vocab_size = len(tgt_base_field.vocab)
pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)
+ if model_opt.share_decoder_embeddings:
+ generator.linear.weight = decoder.embeddings.word_lut.weight

# Load the model states from checkpoint or initialize them.
if checkpoint is not None:

@@ -230,7 +232,8 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):

model.generator = generator
model.to(device)
+ if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam':
+ model.half()
return model
@@ -2,5 +2,4 @@
from onmt.models.model_saver import build_model_saver, ModelSaver
from onmt.models.model import NMTModel

- __all__ = ["build_model_saver", "ModelSaver",
- "NMTModel", "check_sru_requirement"]
+ __all__ = ["build_model_saver", "ModelSaver", "NMTModel"]
@@ -17,7 +17,7 @@ class NMTModel(nn.Module):
self.encoder = encoder
self.decoder = decoder

- def forward(self, src, tgt, lengths, bptt=False):
+ def forward(self, src, tgt, lengths, bptt=False, with_align=False):
"""Forward propagate a `src` and `tgt` pair for training.
Possible initialized with a beginning decoder state.

@@ -26,10 +26,13 @@ class NMTModel(nn.Module):
typically for inputs this will be a padded `LongTensor`
of size ``(len, batch, features)``. However, may be an
image or other generic input depending on encoder.
- tgt (LongTensor): A target sequence of size ``(tgt_len, batch)``.
+ tgt (LongTensor): A target sequence passed to decoder.
+ Size ``(tgt_len, batch, features)``.
lengths(LongTensor): The src lengths, pre-padding ``(batch,)``.
bptt (Boolean): A flag indicating if truncated bptt is set.
If reset then init_state
+ with_align (Boolean): A flag indicating whether output alignment,
+ Only valid for transformer decoder.

Returns:
(FloatTensor, dict[str, FloatTensor]):

@@ -37,14 +40,15 @@ class NMTModel(nn.Module):
* decoder output ``(tgt_len, batch, hidden)``
* dictionary attention dists of ``(tgt_len, batch, src_len)``
"""
- tgt = tgt[:-1]  # exclude last target from inputs
+ dec_in = tgt[:-1]  # exclude last target from inputs

enc_state, memory_bank, lengths = self.encoder(src, lengths)

if bptt is False:
self.decoder.init_state(src, memory_bank, enc_state)
- dec_out, attns = self.decoder(tgt, memory_bank,
- memory_lengths=lengths)
+ dec_out, attns = self.decoder(dec_in, memory_bank,
+ memory_lengths=lengths,
+ with_align=with_align)
return dec_out, attns

def update_dropout(self, dropout):
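The rename from `tgt` to `dec_in` makes the teacher-forcing shift explicit: the decoder consumes every target token except the last, while (in the loss code, assumed here rather than shown in this diff) the gold labels are every token except the first. A toy illustration of that off-by-one slicing:

```python
import torch

# (tgt_len, batch) toy target batch: <s> w1 w2 w3 </s>
tgt = torch.tensor([[2], [10], [11], [12], [3]])

dec_in = tgt[:-1]  # decoder input:  <s> w1 w2 w3
gold = tgt[1:]     # loss targets:   w1 w2 w3 </s> (assumed convention)
print(dec_in.squeeze(1).tolist(), gold.squeeze(1).tolist())
# [2, 10, 11, 12] [10, 11, 12, 3]
```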
@@ -1,6 +1,5 @@
import os
import torch
- import torch.nn as nn

from collections import deque
from onmt.utils.logging import logger

@@ -48,18 +47,20 @@ class ModelSaverBase(object):
if self.keep_checkpoint == 0 or step == self.last_saved_step:
return

- if moving_average:
- save_model = deepcopy(self.model)
- for avg, param in zip(moving_average, save_model.parameters()):
- param.data.copy_(avg.data)
- else:
save_model = self.model
+ if moving_average:
+ model_params_data = []
+ for avg, param in zip(moving_average, save_model.parameters()):
+ model_params_data.append(param.data)
+ param.data = avg.data

chkpt, chkpt_name = self._save(step, save_model)
self.last_saved_step = step

if moving_average:
- del save_model
+ for param_data, param in zip(model_params_data,
+ save_model.parameters()):
+ param.data = param_data

if self.keep_checkpoint > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:

@@ -97,17 +98,10 @@ class ModelSaver(ModelSaverBase):
"""Simple model saver to filesystem"""

def _save(self, step, model):
- real_model = (model.module
- if isinstance(model, nn.DataParallel)
- else model)
- real_generator = (real_model.generator.module
- if isinstance(real_model.generator, nn.DataParallel)
- else real_model.generator)
-
- model_state_dict = real_model.state_dict()
+ model_state_dict = model.state_dict()
model_state_dict = {k: v for k, v in model_state_dict.items()
if 'generator' not in k}
- generator_state_dict = real_generator.state_dict()
+ generator_state_dict = model.generator.state_dict()

# NOTE: We need to trim the vocab to remove any unk tokens that
# were not originally here.

@@ -137,4 +131,5 @@ class ModelSaver(ModelSaverBase):
return checkpoint, checkpoint_path

def _rm_checkpoint(self, name):
+ if os.path.exists(name):
os.remove(name)
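The rewritten saver no longer deep-copies the whole model to apply a moving-average snapshot; it temporarily points each parameter's `.data` at the averaged tensor, saves, and then restores the original tensors. A minimal sketch of that swap-save-restore pattern on a toy module; the moving-average tensors and the checkpoint path are illustrative stand-ins:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
moving_average = [p.data.clone() * 0 for p in model.parameters()]  # stand-in EMA tensors

# swap in averaged weights without copying the whole model
model_params_data = []
for avg, param in zip(moving_average, model.parameters()):
    model_params_data.append(param.data)
    param.data = avg

torch.save(model.state_dict(), "averaged_checkpoint.pt")  # hypothetical path

# restore the live training weights afterwards
for param_data, param in zip(model_params_data, model.parameters()):
    param.data = param_data
```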
@@ -11,6 +11,8 @@ from onmt.modules.embeddings import Embeddings, PositionalEncoding, \
from onmt.modules.weight_norm import WeightNormConv2d
from onmt.modules.average_attn import AverageAttention

+ import onmt.modules.source_noise  # noqa

__all__ = ["Elementwise", "context_gate_factory", "ContextGate",
"GlobalAttention", "ConvMultiStepAttention", "CopyGenerator",
"CopyGeneratorLoss", "CopyGeneratorLossCompute",
@@ -180,7 +180,7 @@ class GlobalAttention(nn.Module):
if memory_lengths is not None:
mask = sequence_mask(memory_lengths, max_len=align.size(-1))
mask = mask.unsqueeze(1)  # Make it broadcastable.
- align.masked_fill_(1 - mask, -float('inf'))
+ align.masked_fill_(~mask, -float('inf'))

# Softmax or sparsemax to normalize attention weights
if self.attn_func == "softmax":
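The switch from `1 - mask` to `~mask` follows newer PyTorch behaviour, where the length mask is a BoolTensor and arithmetic negation on bool tensors is no longer allowed; bitwise negation yields the padding positions directly. A small standalone check of the same masking pattern:

```python
import torch

lengths = torch.tensor([3, 1])
max_len = 4
mask = torch.arange(max_len)[None, :] < lengths[:, None]  # True on real tokens
scores = torch.zeros(2, 1, max_len)
scores.masked_fill_(~mask.unsqueeze(1), -float('inf'))    # -inf on padding positions
print(scores)
```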
@@ -92,7 +92,7 @@ class MultiHeadedAttention(nn.Module):
(FloatTensor, FloatTensor):

* output context vectors ``(batch, query_len, dim)``
- * one of the attention vectors ``(batch, query_len, key_len)``
+ * Attention vector in heads ``(batch, head, query_len, key_len)``.
"""

# CHECKS

@@ -219,13 +219,12 @@ class MultiHeadedAttention(nn.Module):
# aeq(batch, batch_)
# aeq(d, d_)

- # Return one attn
- top_attn = attn \
+ # Return multi-head attn
+ attns = attn \
.view(batch_size, head_count,
- query_len, key_len)[:, 0, :, :] \
- .contiguous()
+ query_len, key_len)

- return output, top_attn
+ return output, attns

def update_dropout(self, dropout):
self.dropout.p = dropout
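Returning the full `(batch, head, query_len, key_len)` tensor lets callers decide how to reduce over heads: the decoder layer above keeps head 0 as `top_attn` and averages heads when building alignments. A quick shape illustration with random weights:

```python
import torch

batch, heads, q_len, k_len = 2, 8, 5, 7
attns = torch.softmax(torch.randn(batch, heads, q_len, k_len), dim=-1)

top_attn = attns[:, 0, :, :].contiguous()  # what the old API used to return
avg_attn = attns.mean(dim=1)               # per-head average, used for alignment
print(top_attn.shape, avg_attn.shape)
# torch.Size([2, 5, 7]) torch.Size([2, 5, 7])
```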
@ -0,0 +1,350 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def aeq(ref, *args):
|
||||||
|
for i, e in enumerate(args):
|
||||||
|
assert ref == e, "%s != %s (element %d)" % (str(ref), str(e), i)
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseBase(object):
|
||||||
|
def __init__(self, prob, pad_idx=1, device_id="cpu",
|
||||||
|
ids_to_noise=[], **kwargs):
|
||||||
|
self.prob = prob
|
||||||
|
self.pad_idx = 1
|
||||||
|
self.skip_first = 1
|
||||||
|
self.device_id = device_id
|
||||||
|
self.ids_to_noise = set([t.item() for t in ids_to_noise])
|
||||||
|
|
||||||
|
def __call__(self, batch):
|
||||||
|
return self.noise_batch(batch)
|
||||||
|
|
||||||
|
def to_device(self, t):
|
||||||
|
return t.to(torch.device(self.device_id))
|
||||||
|
|
||||||
|
def noise_batch(self, batch):
|
||||||
|
source, lengths = batch.src if isinstance(batch.src, tuple) \
|
||||||
|
else (batch.src, [None] * batch.src.size(1))
|
||||||
|
# noise_skip = batch.noise_skip
|
||||||
|
# aeq(len(batch.noise_skip) == source.size(1))
|
||||||
|
|
||||||
|
# source is [src_len x bs x feats]
|
||||||
|
skipped = source[:self.skip_first, :, :]
|
||||||
|
source = source[self.skip_first:]
|
||||||
|
for i in range(source.size(1)):
|
||||||
|
if hasattr(batch, 'corpus_id'):
|
||||||
|
corpus_id = batch.corpus_id[i]
|
||||||
|
if corpus_id.item() not in self.ids_to_noise:
|
||||||
|
continue
|
||||||
|
tokens = source[:, i, 0]
|
||||||
|
mask = tokens.ne(self.pad_idx)
|
||||||
|
|
||||||
|
masked_tokens = tokens[mask]
|
||||||
|
noisy_tokens, length = self.noise_source(
|
||||||
|
masked_tokens, length=lengths[i])
|
||||||
|
|
||||||
|
lengths[i] = length
|
||||||
|
|
||||||
|
# source might increase length so we need to resize the whole
|
||||||
|
# tensor
|
||||||
|
delta = length - (source.size(0) - self.skip_first)
|
||||||
|
if delta > 0:
|
||||||
|
pad = torch.ones([delta],
|
||||||
|
device=source.device,
|
||||||
|
dtype=source.dtype)
|
||||||
|
pad *= self.pad_idx
|
||||||
|
pad = pad.unsqueeze(1).expand(-1, 15).unsqueeze(2)
|
||||||
|
|
||||||
|
source = torch.cat([source, source])
|
||||||
|
source[:noisy_tokens.size(0), i, 0] = noisy_tokens
|
||||||
|
|
||||||
|
source = torch.cat([skipped, source])
|
||||||
|
|
||||||
|
# remove useless pad
|
||||||
|
max_len = lengths.max()
|
||||||
|
source = source[:max_len, :, :]
|
||||||
|
|
||||||
|
batch.src = source, lengths
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def noise_source(self, source, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class MaskNoise(NoiseBase):
|
||||||
|
def noise_batch(self, batch):
|
||||||
|
raise ValueError("MaskNoise has not been updated to tensor noise")
|
||||||
|
# def s(self, tokens):
|
||||||
|
# prob = self.prob
|
||||||
|
# r = torch.rand([len(tokens)])
|
||||||
|
# mask = False
|
||||||
|
# masked = []
|
||||||
|
# for i, tok in enumerate(tokens):
|
||||||
|
# if tok.startswith(subword_prefix):
|
||||||
|
# if r[i].item() <= prob:
|
||||||
|
# masked.append(mask_tok)
|
||||||
|
# mask = True
|
||||||
|
# else:
|
||||||
|
# masked.append(tok)
|
||||||
|
# mask = False
|
||||||
|
# else:
|
||||||
|
# if mask:
|
||||||
|
# pass
|
||||||
|
# else:
|
||||||
|
# masked.append(tok)
|
||||||
|
# return masked
|
||||||
|
|
||||||
|
|
||||||
|
class SenShufflingNoise(NoiseBase):
|
||||||
|
def __init__(self, *args, end_of_sentence_mask=None, **kwargs):
|
||||||
|
super(SenShufflingNoise, self).__init__(*args, **kwargs)
|
||||||
|
assert end_of_sentence_mask is not None
|
||||||
|
self.end_of_sentence_mask = self.to_device(end_of_sentence_mask)
|
||||||
|
|
||||||
|
def is_end_of_sentence(self, source):
|
||||||
|
return self.end_of_sentence_mask.gather(0, source)
|
||||||
|
|
||||||
|
def noise_source(self, source, length=None, **kwargs):
|
||||||
|
# aeq(source.size(0), length)
|
||||||
|
full_stops = self.is_end_of_sentence(source)
|
||||||
|
# Pretend it ends with a full stop so last span is a sentence
|
||||||
|
full_stops[-1] = 1
|
||||||
|
|
||||||
|
# Tokens that are full stops, where the previous token is not
|
||||||
|
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero() + 2
|
||||||
|
result = source.clone()
|
||||||
|
|
||||||
|
num_sentences = sentence_ends.size(0)
|
||||||
|
num_to_permute = math.ceil((num_sentences * 2 * self.prob) / 2.0)
|
||||||
|
substitutions = torch.randperm(num_sentences)[:num_to_permute]
|
||||||
|
ordering = torch.arange(0, num_sentences)
|
||||||
|
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
|
||||||
|
|
||||||
|
index = 0
|
||||||
|
for i in ordering:
|
||||||
|
sentence = source[(sentence_ends[i - 1] if i >
|
||||||
|
0 else 1):sentence_ends[i]]
|
||||||
|
result[index:index + sentence.size(0)] = sentence
|
||||||
|
index += sentence.size(0)
|
||||||
|
# aeq(source.size(0), length)
|
||||||
|
return result, length
|
||||||
|
|
||||||
|
|
||||||
|
class InfillingNoise(NoiseBase):
|
||||||
|
def __init__(self, *args, infilling_poisson_lambda=3.0,
|
||||||
|
word_start_mask=None, **kwargs):
|
||||||
|
super(InfillingNoise, self).__init__(*args, **kwargs)
|
||||||
|
self.poisson_lambda = infilling_poisson_lambda
|
||||||
|
self.mask_span_distribution = self._make_poisson(self.poisson_lambda)
|
||||||
|
self.mask_idx = 0
|
||||||
|
assert word_start_mask is not None
|
||||||
|
self.word_start_mask = self.to_device(word_start_mask)
|
||||||
|
|
||||||
|
# -1: keep everything (i.e. 1 mask per token)
|
||||||
|
# 0: replace everything (i.e. no mask)
|
||||||
|
# 1: 1 mask per span
|
||||||
|
self.replace_length = 1
|
||||||
|
|
||||||
|
def _make_poisson(self, poisson_lambda):
|
||||||
|
# fairseq/data/denoising_dataset.py
|
||||||
|
_lambda = poisson_lambda
|
||||||
|
|
||||||
|
lambda_to_the_k = 1
|
||||||
|
e_to_the_minus_lambda = math.exp(-_lambda)
|
||||||
|
k_factorial = 1
|
||||||
|
ps = []
|
||||||
|
for k in range(0, 128):
|
||||||
|
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
|
||||||
|
lambda_to_the_k *= _lambda
|
||||||
|
k_factorial *= (k + 1)
|
||||||
|
if ps[-1] < 0.0000001:
|
||||||
|
break
|
||||||
|
ps = torch.tensor(ps, device=torch.device(self.device_id))
|
||||||
|
return torch.distributions.Categorical(ps)
|
||||||
|
|
||||||
|
def is_word_start(self, source):
|
||||||
|
# print("src size: ", source.size())
|
||||||
|
# print("ws size: ", self.word_start_mask.size())
|
||||||
|
# print("max: ", source.max())
|
||||||
|
# assert source.max() < self.word_start_mask.size(0)
|
||||||
|
# assert source.min() >= 0
|
||||||
|
return self.word_start_mask.gather(0, source)
|
||||||
|
|
||||||
|
def noise_source(self, source, **kwargs):
|
||||||
|
|
||||||
|
is_word_start = self.is_word_start(source)
|
||||||
|
# assert source.size() == is_word_start.size()
|
||||||
|
# aeq(source.eq(self.pad_idx).long().sum(), 0)
|
||||||
|
|
||||||
|
# we manually add this hypothesis since it's required for the rest
|
||||||
|
# of the function and kindof make sense
|
        is_word_start[-1] = 0

        p = self.prob

        num_to_mask = (is_word_start.float().sum() * p).ceil().long()

        num_inserts = 0

        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(
                sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat([
                    lengths,
                    self.mask_span_distribution.sample(
                        sample_shape=(num_to_mask,))
                ], dim=0)
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(
                    source, num_inserts / source.size(0))
            # assert (lengths > 0).all()
        else:
            raise ValueError("Not supposed to be there")
            lengths = torch.ones((num_to_mask,), device=source.device).long()
        # assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero()
        indices = word_starts[torch.randperm(word_starts.size(0))[
            :num_to_mask]].squeeze(1)

        source_length = source.size(0)
        # TODO why?
        # assert source_length - 1 not in indices
        to_keep = torch.ones(
            source_length,
            dtype=torch.bool,
            device=source.device)

        is_word_start = is_word_start.long()
        # acts as a long length, so spans don't go over the end of doc
        is_word_start[-1] = 10e5
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            # random ratio disabled
            # source[indices[mask_random]] = torch.randint(
            #     1, len(self.vocab), size=(mask_random.sum(),))

        # if self.mask_span_distribution is not None:
        #     assert len(lengths.size()) == 1
        #     assert lengths.size() == indices.size()
        lengths -= 1
        while indices.size(0) > 0:
            # assert lengths.size() == indices.size()
            lengths -= is_word_start[indices + 1].long()
            uncompleted = lengths >= 0
            indices = indices[uncompleted] + 1

            # mask_random = mask_random[uncompleted]
            lengths = lengths[uncompleted]
            if self.replace_length != -1:
                # delete token
                to_keep[indices] = 0
            else:
                # keep index, but replace it with [MASK]
                source[indices] = self.mask_idx
                # random ratio disabled
                # source[indices[mask_random]] = torch.randint(
                #     1, len(self.vocab), size=(mask_random.sum(),))
        # else:
        #     # A bit faster when all lengths are 1
        #     while indices.size(0) > 0:
        #         uncompleted = is_word_start[indices + 1] == 0
        #         indices = indices[uncompleted] + 1
        #         mask_random = mask_random[uncompleted]
        #         if self.replace_length != -1:
        #             # delete token
        #             to_keep[indices] = 0
        #         else:
        #             # keep index, but replace it with [MASK]
        #             source[indices] = self.mask_idx
        #             source[indices[mask_random]] = torch.randint(
        #                 1, len(self.vocab), size=(mask_random.sum(),))

        # assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(
                source, num_inserts / source.size(0))

        # aeq(source.eq(self.pad_idx).long().sum(), 0)
        final_length = source.size(0)
        return source, final_length

    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = tokens.size(0)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(
            size=(num_tokens + n,),
            dtype=torch.bool,
            device=tokens.device)
        noise_mask[noise_indices] = 1
        result = torch.ones([n + len(tokens)],
                            dtype=torch.long,
                            device=tokens.device) * -1

        # random ratio disabled
        # num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices] = self.mask_idx
        # result[noise_indices[:num_random]] = torch.randint(
        #     low=1, high=len(self.vocab), size=(num_random,))

        result[~noise_mask] = tokens

        # assert (result >= 0).all()
        return result


class MultiNoise(NoiseBase):
    NOISES = {
        "sen_shuffling": SenShufflingNoise,
        "infilling": InfillingNoise,
        "mask": MaskNoise
    }

    def __init__(self, noises=[], probs=[], **kwargs):
        assert len(noises) == len(probs)
        super(MultiNoise, self).__init__(probs, **kwargs)

        self.noises = []
        for i, n in enumerate(noises):
            cls = MultiNoise.NOISES.get(n)
            # check the lookup result, not the key, so an unknown name raises
            if cls is None:
                raise ValueError("Unknown noise function '%s'" % n)
            else:
                noise = cls(probs[i], **kwargs)
                self.noises.append(noise)

    def noise_source(self, source, length=None, **kwargs):
        # apply each configured noise function in sequence
        for noise in self.noises:
            source, length = noise.noise_source(
                source, length=length, **kwargs)
        return source, length

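The `InfillingNoise` path above implements BART-style span infilling (a sampled span of word starts is collapsed into a single mask token), and `MultiNoise` chains several such functions. As a library-free toy illustration of the infilling idea only — this is not the implementation above, and the token list and span size are made up:

```python
import random

def toy_infill(tokens, mask_token="<mask>", span=2):
    # Replace one randomly chosen span of `span` tokens with a single
    # mask token, which is the core of BART-style infilling noise.
    if len(tokens) <= span:
        return list(tokens)
    start = random.randrange(len(tokens) - span + 1)
    return tokens[:start] + [mask_token] + tokens[start + span:]

print(toy_infill("the cat sat on the mat".split()))
# e.g. ['the', 'cat', '<mask>', 'the', 'mat']
```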
@@ -2,6 +2,8 @@
from __future__ import print_function

import configargparse
+import onmt
+
from onmt.models.sru import CheckSRU


@@ -12,6 +14,149 @@ def config_opts(parser):
              is_write_out_config_file_arg=True,
              help='config file save path')

+
+def global_opts(parser):
+    group = parser.add_argument_group('Model')
+    group.add('--fp32', '-fp32', action='store_true',
+              help="Force the model to be in FP32 "
+                   "because FP16 is very slow on GTX1080(ti).")
+    group.add('--avg_raw_probs', '-avg_raw_probs', action='store_true',
+              help="If this is set, during ensembling scores from "
+                   "different models will be combined by averaging their "
+                   "raw probabilities and then taking the log. Otherwise, "
+                   "the log probabilities will be averaged directly. "
+                   "Necessary for models whose output layers can assign "
+                   "zero probability.")
+
+    group = parser.add_argument_group('Data')
+    group.add('--data_type', '-data_type', default="text",
+              help="Type of the source input. Options: [text|img].")
+
+    group.add('--shard_size', '-shard_size', type=int, default=10000,
+              help="Divide src and tgt (if applicable) into "
+                   "smaller multiple src and tgt files, then "
+                   "build shards, each shard will have "
+                   "opt.shard_size samples except last shard. "
+                   "shard_size=0 means no segmentation "
+                   "shard_size>0 means segment dataset into multiple shards, "
+                   "each shard has shard_size samples")
+    group.add('--output', '-output', default='pred.txt',
+              help="Path to output the predictions (each line will "
+                   "be the decoded sequence")
+    group.add('--report_align', '-report_align', action='store_true',
+              help="Report alignment for each translation.")
+    group.add('--report_time', '-report_time', action='store_true',
+              help="Report some translation time metrics")
+
+    group = parser.add_argument_group('Random Sampling')
+    group.add('--random_sampling_topk', '-random_sampling_topk',
+              default=1, type=int,
+              help="Set this to -1 to do random sampling from full "
+                   "distribution. Set this to value k>1 to do random "
+                   "sampling restricted to the k most likely next tokens. "
+                   "Set this to 1 to use argmax or for doing beam "
+                   "search.")
+    group.add('--random_sampling_temp', '-random_sampling_temp',
+              default=1., type=float,
+              help="If doing random sampling, divide the logits by "
+                   "this before computing softmax during decoding.")
+    group.add('--seed', '-seed', type=int, default=829,
+              help="Random seed")
+
+    group = parser.add_argument_group('Beam')
+    group.add('--beam_size', '-beam_size', type=int, default=5,
+              help='Beam size')
+    group.add('--min_length', '-min_length', type=int, default=0,
+              help='Minimum prediction length')
+    group.add('--max_length', '-max_length', type=int, default=100,
+              help='Maximum prediction length.')
+    group.add('--max_sent_length', '-max_sent_length', action=DeprecateAction,
+              help="Deprecated, use `-max_length` instead")
+
+    # Alpha and Beta values for Google Length + Coverage penalty
+    # Described here: https://arxiv.org/pdf/1609.08144.pdf, Section 7
+    group.add('--stepwise_penalty', '-stepwise_penalty', action='store_true',
+              help="Apply penalty at every decoding step. "
+                   "Helpful for summary penalty.")
+    group.add('--length_penalty', '-length_penalty', default='none',
+              choices=['none', 'wu', 'avg'],
+              help="Length Penalty to use.")
+    group.add('--ratio', '-ratio', type=float, default=-0.,
+              help="Ratio based beam stop condition")
+    group.add('--coverage_penalty', '-coverage_penalty', default='none',
+              choices=['none', 'wu', 'summary'],
+              help="Coverage Penalty to use.")
+    group.add('--alpha', '-alpha', type=float, default=0.,
+              help="Google NMT length penalty parameter "
+                   "(higher = longer generation)")
+    group.add('--beta', '-beta', type=float, default=-0.,
+              help="Coverage penalty parameter")
+    group.add('--block_ngram_repeat', '-block_ngram_repeat',
+              type=int, default=0,
+              help='Block repetition of ngrams during decoding.')
+    group.add('--ignore_when_blocking', '-ignore_when_blocking',
+              nargs='+', type=str, default=[],
+              help="Ignore these strings when blocking repeats. "
+                   "You want to block sentence delimiters.")
+    group.add('--replace_unk', '-replace_unk', action="store_true",
+              help="Replace the generated UNK tokens with the "
+                   "source token that had highest attention weight. If "
+                   "phrase_table is provided, it will look up the "
+                   "identified source token and give the corresponding "
+                   "target token. If it is not provided (or the identified "
+                   "source token does not exist in the table), then it "
+                   "will copy the source token.")
+    group.add('--phrase_table', '-phrase_table', type=str, default="",
+              help="If phrase_table is provided (with replace_unk), it will "
+                   "look up the identified source token and give the "
+                   "corresponding target token. If it is not provided "
+                   "(or the identified source token does not exist in "
+                   "the table), then it will copy the source token.")
+
+    group = parser.add_argument_group('Logging')
+    group.add('--verbose', '-verbose', action="store_true",
+              help='Print scores and predictions for each sentence')
+    group.add('--log_file', '-log_file', type=str, default="",
+              help="Output logs to a file under this path.")
+    group.add('--log_file_level', '-log_file_level', type=str,
+              action=StoreLoggingLevelAction,
+              choices=StoreLoggingLevelAction.CHOICES,
+              default="0")
+    group.add('--attn_debug', '-attn_debug', action="store_true",
+              help='Print best attn for each word')
+    group.add('--align_debug', '-align_debug', action="store_true",
+              help='Print best align for each word')
+    group.add('--dump_beam', '-dump_beam', type=str, default="",
+              help='File to dump beam information to.')
+    group.add('--n_best', '-n_best', type=int, default=1,
+              help="If verbose is set, will output the n_best "
+                   "decoded sentences")
+
+    group = parser.add_argument_group('Efficiency')
+    group.add('--batch_size', '-batch_size', type=int, default=30,
+              help='Batch size')
+    group.add('--batch_type', '-batch_type', default='sents',
+              choices=["sents", "tokens"],
+              help="Batch grouping for batch_size. Standard "
+                   "is sents. Tokens will do dynamic batching")
+    group.add('--gpu', '-gpu', type=int, default=-1,
+              help="Device to run on")
+
+    # Options most relevant to speech.
+    group = parser.add_argument_group('Speech')
+    group.add('--sample_rate', '-sample_rate', type=int, default=16000,
+              help="Sample rate.")
+    group.add('--window_size', '-window_size', type=float, default=.02,
+              help='Window size for spectrogram in seconds')
+    group.add('--window_stride', '-window_stride', type=float, default=.01,
+              help='Window stride for spectrogram in seconds')
+    group.add('--window', '-window', default='hamming',
+              help='Window type for spectrogram generation')
+
+    # Option most relevant to image input
+    group.add('--image_channel_size', '-image_channel_size',
+              type=int, default=3, choices=[3, 1],
+              help="Using grayscale image can training "
+                   "model faster and smaller")
+

def model_opts(parser):
    """
@@ -69,10 +214,10 @@ def model_opts(parser):
              help='Data type of the model.')

    group.add('--encoder_type', '-encoder_type', type=str, default='rnn',
-              choices=['rnn', 'brnn', 'mean', 'transformer', 'cnn'],
+              choices=['rnn', 'brnn', 'ggnn', 'mean', 'transformer', 'cnn'],
              help="Type of encoder layer to use. Non-RNN layers "
                   "are experimental. Options are "
-                   "[rnn|brnn|mean|transformer|cnn].")
+                   "[rnn|brnn|ggnn|mean|transformer|cnn].")
    group.add('--decoder_type', '-decoder_type', type=str, default='rnn',
              choices=['rnn', 'transformer', 'cnn'],
              help="Type of decoder layer to use. Non-RNN layers "

@@ -128,6 +273,27 @@ def model_opts(parser):
              help="Type of context gate to use. "
                   "Do not select for no context gate.")

+    # The following options (bridge_extra_node to src_vocab) are used
+    # for training with --encoder_type ggnn (Gated Graph Neural Network).
+    group.add('--bridge_extra_node', '-bridge_extra_node',
+              type=bool, default=True,
+              help='Graph encoder bridges only extra node to decoder as input')
+    group.add('--bidir_edges', '-bidir_edges', type=bool, default=True,
+              help='Graph encoder autogenerates bidirectional edges')
+    group.add('--state_dim', '-state_dim', type=int, default=512,
+              help='Number of state dimensions in the graph encoder')
+    group.add('--n_edge_types', '-n_edge_types', type=int, default=2,
+              help='Number of edge types in the graph encoder')
+    group.add('--n_node', '-n_node', type=int, default=2,
+              help='Number of nodes in the graph encoder')
+    group.add('--n_steps', '-n_steps', type=int, default=2,
+              help='Number of steps to advance graph encoder')
+    # The ggnn uses src_vocab during training because the graph is built
+    # using edge information which requires parsing the input sequence.
+    group.add('--src_vocab', '-src_vocab', default="",
+              help="Path to an existing source vocabulary. Format: "
+                   "one word per line.")
+
    # Attention options
    group = parser.add_argument_group('Model- Attention')
    group.add('--global_attention', '-global_attention',
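The ggnn options above only take effect together with `-encoder_type ggnn`. As a hedged sanity-check sketch — not taken from this commit — of how these flags parse together, assuming the package is installed so that `onmt.opts` imports cleanly and using placeholder values:

```python
import configargparse
from onmt import opts

# Hypothetical values; '-src_vocab' names a vocabulary file the ggnn
# encoder needs for graph building (only parsed here, never opened).
parser = configargparse.ArgumentParser()
opts.model_opts(parser)
args, _ = parser.parse_known_args([
    '-encoder_type', 'ggnn',
    '-state_dim', '256',
    '-n_edge_types', '2',
    '-n_node', '64',
    '-n_steps', '2',
    '-src_vocab', 'data/src.vocab',
])
print(args.encoder_type, args.state_dim, args.n_steps)
```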
@@ -154,7 +320,22 @@ def model_opts(parser):
    group.add('--aan_useffn', '-aan_useffn', action="store_true",
              help='Turn on the FFN layer in the AAN decoder')

+    # Alignement options
+    group = parser.add_argument_group('Model - Alignement')
+    group.add('--lambda_align', '-lambda_align', type=float, default=0.0,
+              help="Lambda value for alignement loss of Garg et al (2019)"
+                   "For more detailed information, see: "
+                   "https://arxiv.org/abs/1909.02074")
+    group.add('--alignment_layer', '-alignment_layer', type=int, default=-3,
+              help='Layer number which has to be supervised.')
+    group.add('--alignment_heads', '-alignment_heads', type=int, default=0,
+              help='N. of cross attention heads per layer to supervised with')
+    group.add('--full_context_alignment', '-full_context_alignment',
+              action="store_true",
+              help='Whether alignment is conditioned on full target context.')
+
    # Generator and loss options.
+    group = parser.add_argument_group('Generator')
    group.add('--copy_attn', '-copy_attn', action="store_true",
              help='Train copy attention layer.')
    group.add('--copy_attn_type', '-copy_attn_type',
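The `-lambda_align` weight above follows the usual pattern for an auxiliary supervised-attention objective: the alignment loss is added to the NMT loss with that coefficient. A minimal sketch of the combination — the two loss tensors are placeholders, not functions from this commit:

```python
def combined_loss(nmt_loss, align_loss, lambda_align=0.0):
    # With the default lambda_align=0.0 the alignment term vanishes and
    # training reduces to the plain NMT objective.
    return nmt_loss + lambda_align * align_loss
```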
@@ -181,7 +362,7 @@ def model_opts(parser):
    group.add('--loss_scale', '-loss_scale', type=float, default=0,
              help="For FP16 training, the static loss scale to use. If not "
                   "set, the loss scale is dynamically computed.")
-    group.add('--apex_opt_level', '-apex_opt_level', type=str, default="O2",
+    group.add('--apex_opt_level', '-apex_opt_level', type=str, default="O1",
              choices=["O0", "O1", "O2", "O3"],
              help="For FP16 training, the opt_level to use."
                   "See https://nvidia.github.io/apex/amp.html#opt-levels.")

@@ -199,12 +380,17 @@ def preprocess_opts(parser):
              help="Path(s) to the training source data")
    group.add('--train_tgt', '-train_tgt', required=True, nargs='+',
              help="Path(s) to the training target data")
+    group.add('--train_align', '-train_align', nargs='+', default=[None],
+              help="Path(s) to the training src-tgt alignment")
    group.add('--train_ids', '-train_ids', nargs='+', default=[None],
              help="ids to name training shards, used for corpus weighting")

    group.add('--valid_src', '-valid_src',
              help="Path to the validation source data")
    group.add('--valid_tgt', '-valid_tgt',
              help="Path to the validation target data")
+    group.add('--valid_align', '-valid_align', default=None,
+              help="Path(s) to the validation src-tgt alignment")
+
    group.add('--src_dir', '-src_dir', default="",
              help="Source directory for image or audio files.")

@@ -224,6 +410,9 @@ def preprocess_opts(parser):
                   "shard_size>0 means segment dataset into multiple shards, "
                   "each shard has shard_size samples")

+    group.add('--num_threads', '-num_threads', type=int, default=1,
+              help="Number of shards to build in parallel.")
+
    group.add('--overwrite', '-overwrite', action="store_true",
              help="Overwrite existing shards if any.")

@@ -310,6 +499,15 @@ def preprocess_opts(parser):
              help="Using grayscale image can training "
                   "model faster and smaller")

+    # Options for experimental source noising (BART style)
+    group = parser.add_argument_group('Noise')
+    group.add('--subword_prefix', '-subword_prefix',
+              type=str, default="▁",
+              help="subword prefix to build wordstart mask")
+    group.add('--subword_prefix_is_joiner', '-subword_prefix_is_joiner',
+              action='store_true',
+              help="mask will need to be inverted if prefix is joiner")
+

def train_opts(parser):
    """ Training and saving options """

@@ -324,6 +522,8 @@ def train_opts(parser):
    group.add('--data_weights', '-data_weights', type=int, nargs='+',
              default=[1], help="""Weights of different corpora,
              should follow the same order as in -data_ids.""")
+    group.add('--data_to_noise', '-data_to_noise', nargs='+', default=[],
+              help="IDs of datasets on which to apply noise.")

    group.add('--save_model', '-save_model', default='model',
              help="Model filename (the model will be saved as "

@@ -352,7 +552,7 @@ def train_opts(parser):
              help="IP of master for torch.distributed training.")
    group.add('--master_port', '-master_port', default=10000, type=int,
              help="Port of master for torch.distributed training.")
-    group.add('--queue_size', '-queue_size', default=400, type=int,
+    group.add('--queue_size', '-queue_size', default=40, type=int,
              help="Size of queue for each process in producer/consumer")

    group.add('--seed', '-seed', type=int, default=-1,

@@ -472,10 +672,9 @@ def train_opts(parser):
                   'Typically a value of 0.999 is recommended, as this is '
                   'the value suggested by the original paper describing '
                   'Adam, and is also the value adopted in other frameworks '
-                   'such as Tensorflow and Kerras, i.e. see: '
+                   'such as Tensorflow and Keras, i.e. see: '
                   'https://www.tensorflow.org/api_docs/python/tf/train/Adam'
-                   'Optimizer or '
-                   'https://keras.io/optimizers/ . '
+                   'Optimizer or https://keras.io/optimizers/ . '
                   'Whereas recently the paper "Attention is All You Need" '
                   'suggested a value of 0.98 for beta2, this parameter may '
                   'not work well for normal models / default '

@@ -498,6 +697,12 @@ def train_opts(parser):
              help="Step for moving average. "
                   "Default is every update, "
                   "if -average_decay is set.")
+    group.add("--src_noise", "-src_noise", type=str, nargs='+',
+              default=[],
+              choices=onmt.modules.source_noise.MultiNoise.NOISES.keys())
+    group.add("--src_noise_prob", "-src_noise_prob", type=float, nargs='+',
+              default=[],
+              help="Probabilities of src_noise functions")

    # learning rate
    group = parser.add_argument_group('Optimization- Rate')
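The `-src_noise` choices come straight from `MultiNoise.NOISES`, `-src_noise_prob` supplies one probability per chosen noise, and `-data_to_noise` (added earlier in train_opts) restricts noising to particular dataset ids. A hedged sketch of how the parsed options could be turned into a noiser — the keyword arguments and wiring are assumptions, not the actual data-pipeline code:

```python
from onmt.modules.source_noise import MultiNoise

def build_noiser(opt, **noise_kwargs):
    # opt.src_noise e.g. ["infilling", "mask"]; opt.src_noise_prob e.g. [0.3, 0.1].
    # noise_kwargs stands in for whatever NoiseBase expects (pad index, etc.).
    if not opt.src_noise:
        return None
    return MultiNoise(noises=opt.src_noise,
                      probs=opt.src_noise_prob,
                      **noise_kwargs)
```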
@@ -536,10 +741,10 @@ def train_opts(parser):
              help="Send logs to this crayon server.")
    group.add('--exp', '-exp', type=str, default="",
              help="Name of the experiment for logging.")
-    # Use TensorboardX for visualization during training
+    # Use Tensorboard for visualization during training
    group.add('--tensorboard', '-tensorboard', action="store_true",
-              help="Use tensorboardX for visualization during training. "
-                   "Must have the library tensorboardX.")
+              help="Use tensorboard for visualization during training. "
+                   "Must have the library tensorboard >= 1.14.")
    group.add("--tensorboard_log_dir", "-tensorboard_log_dir",
              type=str, default="runs/onmt",
              help="Log directory for Tensorboard. "

@@ -600,12 +805,8 @@ def translate_opts(parser):
    group.add('--output', '-output', default='pred.txt',
              help="Path to output the predictions (each line will "
                   "be the decoded sequence")
-    group.add('--report_bleu', '-report_bleu', action='store_true',
-              help="Report bleu score after translation, "
-                   "call tools/multi-bleu.perl on command line")
-    group.add('--report_rouge', '-report_rouge', action='store_true',
-              help="Report rouge 1/2/3/L/SU4 score after translation "
-                   "call tools/test_rouge.py on command line")
+    group.add('--report_align', '-report_align', action='store_true',
+              help="Report alignment for each translation.")
    group.add('--report_time', '-report_time', action='store_true',
              help="Report some translation time metrics")

@@ -690,6 +891,8 @@ def translate_opts(parser):
              default="0")
    group.add('--attn_debug', '-attn_debug', action="store_true",
              help='Print best attn for each word')
+    group.add('--align_debug', '-align_debug', action="store_true",
+              help='Print best align for each word')
    group.add('--dump_beam', '-dump_beam', type=str, default="",
              help='File to dump beam information to.')
    group.add('--n_best', '-n_best', type=int, default=1,

@@ -699,6 +902,10 @@ def translate_opts(parser):
    group = parser.add_argument_group('Efficiency')
    group.add('--batch_size', '-batch_size', type=int, default=30,
              help='Batch size')
+    group.add('--batch_type', '-batch_type', default='sents',
+              choices=["sents", "tokens"],
+              help="Batch grouping for batch_size. Standard "
+                   "is sents. Tokens will do dynamic batching")
    group.add('--gpu', '-gpu', type=int, default=-1,
              help="Device to run on")

@@ -1,383 +0,0 @@
import unittest
from onmt.translate.beam import Beam, GNMTGlobalScorer

import torch


class GlobalScorerStub(object):
    def update_global_state(self, beam):
        pass

    def score(self, beam, scores):
        return scores


class TestBeam(unittest.TestCase):
    BLOCKED_SCORE = -10e20

    def test_advance_with_all_repeats_gets_blocked(self):
        # all beams repeat (beam >= 1 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                    exclusion_tokens=set(),
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=ngram_repeat)
        for i in range(ngram_repeat + 4):
            # predict repeat_idx over and over again
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            word_probs[0, repeat_idx] = 0
            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i <= ngram_repeat:
                self.assertTrue(
                    beam.scores.equal(
                        torch.tensor(
                            [0] + [-float('inf')] * (beam_sz - 1))))
            else:
                self.assertTrue(
                    beam.scores.equal(torch.tensor(
                        [self.BLOCKED_SCORE] * beam_sz)))

    def test_advance_with_some_repeats_gets_blocked(self):
        # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47
        ngram_repeat = 3
        beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                    exclusion_tokens=set(),
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=ngram_repeat)
        for i in range(ngram_repeat + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # on initial round, only predicted scores for beam 0
                # matter. Make two predictions. Top one will be repeated
                # in beam zero, second one will live on in beam 1.
                word_probs[0, repeat_idx] = -0.1
                word_probs[0, repeat_idx + i + 1] = -2.3
            else:
                # predict the same thing in beam 0
                word_probs[0, repeat_idx] = 0
                # continue pushing around what beam 1 predicts
                word_probs[1, repeat_idx + i + 1] = 0
            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i <= ngram_repeat:
                self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
                self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
            else:
                # now beam 0 dies (along with the others), beam 1 -> beam 0
                self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
                self.assertTrue(
                    beam.scores[1:].equal(torch.tensor(
                        [self.BLOCKED_SCORE] * (beam_sz - 1))))

    def test_repeating_excluded_index_does_not_die(self):
        # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
        beam_sz = 5
        n_words = 100
        repeat_idx = 47  # will be repeated and should be blocked
        repeat_idx_ignored = 7  # will be repeated and should not be blocked
        ngram_repeat = 3
        beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                    exclusion_tokens=set([repeat_idx_ignored]),
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=ngram_repeat)
        for i in range(ngram_repeat + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                word_probs[0, repeat_idx] = -0.1
                word_probs[0, repeat_idx + i + 1] = -2.3
                word_probs[0, repeat_idx_ignored] = -5.0
            else:
                # predict the same thing in beam 0
                word_probs[0, repeat_idx] = 0
                # continue pushing around what beam 1 predicts
                word_probs[1, repeat_idx + i + 1] = 0
                # predict the allowed-repeat again in beam 2
                word_probs[2, repeat_idx_ignored] = 0
            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i <= ngram_repeat:
                self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
                self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
                self.assertFalse(beam.scores[2].eq(self.BLOCKED_SCORE))
            else:
                # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
                # and the rest die
                self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
                # since all preds after i=0 are 0, we can check
                # that the beam is the correct idx by checking that
                # the curr score is the initial score
                self.assertTrue(beam.scores[0].eq(-2.3))
                self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
                self.assertTrue(beam.scores[1].eq(-5.0))
                self.assertTrue(
                    beam.scores[2:].equal(torch.tensor(
                        [self.BLOCKED_SCORE] * (beam_sz - 2))))

    def test_doesnt_predict_eos_if_shorter_than_min_len(self):
        # beam 0 will always predict EOS. The other beams will predict
        # non-eos scores.
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            else:
                # predict eos in beam 0
                word_probs[0, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                expected_score_dist = (i+1) * valid_score_dist[1:]
                self.assertTrue(beam.scores.allclose(expected_score_dist))
            elif i == min_length:
                # now the top beam has ended and no others have
                # first beam finished had length beam.min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                # first beam finished was 0
                self.assertEqual(beam.finished[0][2], 0)
            else:  # i > min_length
                # not of interest, but want to make sure it keeps running
                # since only beam 0 terminates and n_best = 2
                pass

    def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
        # this is also a test that when block_ngram_repeat=0,
        # repeating is acceptable
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                    exclusion_tokens=set(),
                    min_length=min_length,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full((beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0, j] = score
            elif i <= min_length:
                # predict eos in beam 1
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score
            else:
                word_probs[0, eos_idx] = valid_score_dist[0]
                word_probs[1, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions in other beams
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz-1, k)
                    word_probs[beam_idx, j] = score

            attns = torch.randn(beam_sz)
            beam.advance(word_probs, attns)
            if i < min_length:
                self.assertFalse(beam.done)
            elif i == min_length:
                # beam 1 dies on min_length
                self.assertEqual(beam.finished[0][1], beam.min_length + 1)
                self.assertEqual(beam.finished[0][2], 1)
                self.assertFalse(beam.done)
            else:  # i > min_length
                # beam 0 dies on the step after beam 1 dies
                self.assertEqual(beam.finished[1][1], beam.min_length + 2)
                self.assertEqual(beam.finished[1][2], 0)
                self.assertTrue(beam.done)


class TestBeamAgainstReferenceCase(unittest.TestCase):
    BEAM_SZ = 5
    EOS_IDX = 2  # don't change this - all the scores would need updated
    N_WORDS = 8  # also don't change for same reason
    N_BEST = 3
    DEAD_SCORE = -1e20
    INP_SEQ_LEN = 53

    def init_step(self, beam):
        # init_preds: [4, 3, 5, 6, 7] - no EOS's
        init_scores = torch.log_softmax(torch.tensor(
            [[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float), dim=1)
        expected_beam_scores, expected_preds_0 = init_scores.topk(self.BEAM_SZ)
        beam.advance(init_scores, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))
        self.assertTrue(beam.scores.allclose(expected_beam_scores))
        self.assertTrue(beam.next_ys[-1].equal(expected_preds_0[0]))
        self.assertFalse(beam.eos_top)
        self.assertFalse(beam.done)
        return expected_beam_scores

    def first_step(self, beam, expected_beam_scores, expected_len_pen):
        # no EOS's yet
        assert len(beam.finished) == 0
        scores_1 = torch.log_softmax(torch.tensor(
            [[0, 0, 0, .3, 0, .51, .2, 0],
             [0, 0, 1.5, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, .49, .48, 0, 0],
             [0, 0, 0, .2, .2, .2, .2, .2],
             [0, 0, 0, .2, .2, .2, .2, .2]]
        ), dim=1)

        beam.advance(scores_1, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))

        new_scores = scores_1 + expected_beam_scores.t()
        expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
            self.BEAM_SZ, 0, True, True)
        expected_bptr_1 = unreduced_preds / self.N_WORDS
        # [5, 3, 2, 6, 0], so beam 2 predicts EOS!
        expected_preds_1 = unreduced_preds - expected_bptr_1 * self.N_WORDS

        self.assertTrue(beam.scores.allclose(expected_beam_scores))
        self.assertTrue(beam.next_ys[-1].equal(expected_preds_1))
        self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_1))
        self.assertEqual(len(beam.finished), 1)
        self.assertEqual(beam.finished[0][2], 2)  # beam 2 finished
        self.assertEqual(beam.finished[0][1], 2)  # finished on second step
        self.assertEqual(beam.finished[0][0],  # finished with correct score
                         expected_beam_scores[2] / expected_len_pen)
        self.assertFalse(beam.eos_top)
        self.assertFalse(beam.done)
        return expected_beam_scores

    def second_step(self, beam, expected_beam_scores, expected_len_pen):
        # assumes beam 2 finished on last step
        scores_2 = torch.log_softmax(torch.tensor(
            [[0, 0, 0, .3, 0, .51, .2, 0],
             [0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 5000, .48, 0, 0],  # beam 2 shouldn't continue
             [0, 0, 50, .2, .2, .2, .2, .2],  # beam 3 -> beam 0 should die
             [0, 0, 0, .2, .2, .2, .2, .2]]
        ), dim=1)

        beam.advance(scores_2, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))

        new_scores = scores_2 + expected_beam_scores.unsqueeze(1)
        new_scores[2] = self.DEAD_SCORE  # ended beam 2 shouldn't continue
        expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
            self.BEAM_SZ, 0, True, True)
        expected_bptr_2 = unreduced_preds / self.N_WORDS
        # [2, 5, 3, 6, 0], so beam 0 predicts EOS!
        expected_preds_2 = unreduced_preds - expected_bptr_2 * self.N_WORDS
        # [-2.4879, -3.8910, -4.1010, -4.2010, -4.4010]
        self.assertTrue(beam.scores.allclose(expected_beam_scores))
        self.assertTrue(beam.next_ys[-1].equal(expected_preds_2))
        self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_2))
        self.assertEqual(len(beam.finished), 2)
        # new beam 0 finished
        self.assertEqual(beam.finished[1][2], 0)
        # new beam 0 is old beam 3
        self.assertEqual(expected_bptr_2[0], 3)
        self.assertEqual(beam.finished[1][1], 3)  # finished on third step
        self.assertEqual(beam.finished[1][0],  # finished with correct score
                         expected_beam_scores[0] / expected_len_pen)
        self.assertTrue(beam.eos_top)
        self.assertFalse(beam.done)
        return expected_beam_scores

    def third_step(self, beam, expected_beam_scores, expected_len_pen):
        # assumes beam 0 finished on last step
        scores_3 = torch.log_softmax(torch.tensor(
            [[0, 0, 5000, 0, 5000, .51, .2, 0],  # beam 0 shouldn't cont
             [0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 5000, 0, 0],
             [0, 0, 0, .2, .2, .2, .2, .2],
             [0, 0, 50, 0, .2, .2, .2, .2]]  # beam 4 -> beam 1 should die
        ), dim=1)

        beam.advance(scores_3, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))

        new_scores = scores_3 + expected_beam_scores.unsqueeze(1)
        new_scores[0] = self.DEAD_SCORE  # ended beam 2 shouldn't continue
        expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
            self.BEAM_SZ, 0, True, True)
        expected_bptr_3 = unreduced_preds / self.N_WORDS
        # [5, 2, 6, 1, 0], so beam 1 predicts EOS!
        expected_preds_3 = unreduced_preds - expected_bptr_3 * self.N_WORDS
        self.assertTrue(beam.scores.allclose(expected_beam_scores))
        self.assertTrue(beam.next_ys[-1].equal(expected_preds_3))
        self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_3))
        self.assertEqual(len(beam.finished), 3)
        # new beam 1 finished
        self.assertEqual(beam.finished[2][2], 1)
        # new beam 1 is old beam 4
        self.assertEqual(expected_bptr_3[1], 4)
        self.assertEqual(beam.finished[2][1], 4)  # finished on fourth step
        self.assertEqual(beam.finished[2][0],  # finished with correct score
                         expected_beam_scores[1] / expected_len_pen)
        self.assertTrue(beam.eos_top)
        self.assertTrue(beam.done)
        return expected_beam_scores

    def test_beam_advance_against_known_reference(self):
        beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
                    exclusion_tokens=set(),
                    min_length=0,
                    global_scorer=GlobalScorerStub(),
                    block_ngram_repeat=0)

        expected_beam_scores = self.init_step(beam)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
        self.third_step(beam, expected_beam_scores, 1)


class TestBeamWithLengthPenalty(TestBeamAgainstReferenceCase):
    # this could be considered an integration test because it tests
    # interactions between the GNMT scorer and the beam

    def test_beam_advance_against_known_reference(self):
        scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
        beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
                    exclusion_tokens=set(),
                    min_length=0,
                    global_scorer=scorer,
                    block_ngram_repeat=0)
        expected_beam_scores = self.init_step(beam)
        expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
        expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
        self.third_step(beam, expected_beam_scores, 5)

@ -1,6 +1,5 @@
|
||||||
import unittest
|
import unittest
|
||||||
from onmt.translate.beam import GNMTGlobalScorer
|
from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer
|
||||||
from onmt.translate.beam_search import BeamSearch
|
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
@ -34,29 +33,49 @@ class TestBeamSearch(unittest.TestCase):
|
||||||
n_words = 100
|
n_words = 100
|
||||||
repeat_idx = 47
|
repeat_idx = 47
|
||||||
ngram_repeat = 3
|
ngram_repeat = 3
|
||||||
|
device_init = torch.zeros(1, 1)
|
||||||
for batch_sz in [1, 3]:
|
for batch_sz in [1, 3]:
|
||||||
beam = BeamSearch(
|
beam = BeamSearch(
|
||||||
beam_sz, batch_sz, 0, 1, 2, 2,
|
beam_sz, batch_sz, 0, 1, 2, 2,
|
||||||
torch.device("cpu"), GlobalScorerStub(), 0, 30,
|
GlobalScorerStub(), 0, 30,
|
||||||
False, ngram_repeat, set(),
|
False, ngram_repeat, set(),
|
||||||
torch.randint(0, 30, (batch_sz,)), False, 0.)
|
False, 0.)
|
||||||
|
beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
|
||||||
for i in range(ngram_repeat + 4):
|
for i in range(ngram_repeat + 4):
|
||||||
# predict repeat_idx over and over again
|
# predict repeat_idx over and over again
|
||||||
word_probs = torch.full(
|
word_probs = torch.full(
|
||||||
(batch_sz * beam_sz, n_words), -float('inf'))
|
(batch_sz * beam_sz, n_words), -float('inf'))
|
||||||
word_probs[0::beam_sz, repeat_idx] = 0
|
word_probs[0::beam_sz, repeat_idx] = 0
|
||||||
|
|
||||||
attns = torch.randn(1, batch_sz * beam_sz, 53)
|
attns = torch.randn(1, batch_sz * beam_sz, 53)
|
||||||
beam.advance(word_probs, attns)
|
beam.advance(word_probs, attns)
|
||||||
if i <= ngram_repeat:
|
|
||||||
|
if i < ngram_repeat:
|
||||||
|
# before repeat, scores are either 0 or -inf
|
||||||
expected_scores = torch.tensor(
|
expected_scores = torch.tensor(
|
||||||
[0] + [-float('inf')] * (beam_sz - 1))\
|
[0] + [-float('inf')] * (beam_sz - 1))\
|
||||||
.repeat(batch_sz, 1)
|
.repeat(batch_sz, 1)
|
||||||
self.assertTrue(beam.topk_log_probs.equal(expected_scores))
|
self.assertTrue(beam.topk_log_probs.equal(expected_scores))
|
||||||
|
elif i % ngram_repeat == 0:
|
||||||
|
# on repeat, `repeat_idx` score is BLOCKED_SCORE
|
||||||
|
# (but it's still the best score, thus we have
|
||||||
|
# [BLOCKED_SCORE, -inf, -inf, -inf, -inf]
|
||||||
|
expected_scores = torch.tensor(
|
||||||
|
[0] + [-float('inf')] * (beam_sz - 1))\
|
||||||
|
.repeat(batch_sz, 1)
|
||||||
|
expected_scores[:, 0] = self.BLOCKED_SCORE
|
||||||
|
self.assertTrue(beam.topk_log_probs.equal(expected_scores))
|
||||||
else:
|
else:
|
||||||
self.assertTrue(
|
# repetitions keeps maximizing score
|
||||||
beam.topk_log_probs.equal(
|
# index 0 has been blocked, so repeating=>+0.0 score
|
||||||
torch.tensor(self.BLOCKED_SCORE)
|
# other indexes are -inf so repeating=>BLOCKED_SCORE
|
||||||
.repeat(batch_sz, beam_sz)))
|
# which is higher
|
||||||
|
expected_scores = torch.tensor(
|
||||||
|
[0] + [-float('inf')] * (beam_sz - 1))\
|
||||||
|
.repeat(batch_sz, 1)
|
||||||
|
expected_scores[:, :] = self.BLOCKED_SCORE
|
||||||
|
expected_scores = torch.tensor(
|
||||||
|
self.BLOCKED_SCORE).repeat(batch_sz, beam_sz)
|
||||||
|
|
||||||
def test_advance_with_some_repeats_gets_blocked(self):
|
def test_advance_with_some_repeats_gets_blocked(self):
|
||||||
# beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
|
# beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
|
||||||
|
@ -64,12 +83,16 @@ class TestBeamSearch(unittest.TestCase):
|
||||||
n_words = 100
|
n_words = 100
|
||||||
repeat_idx = 47
|
repeat_idx = 47
|
||||||
ngram_repeat = 3
|
ngram_repeat = 3
|
||||||
|
no_repeat_score = -2.3
|
||||||
|
repeat_score = -0.1
|
||||||
|
device_init = torch.zeros(1, 1)
|
||||||
for batch_sz in [1, 3]:
|
for batch_sz in [1, 3]:
|
||||||
beam = BeamSearch(
|
beam = BeamSearch(
|
||||||
beam_sz, batch_sz, 0, 1, 2, 2,
|
beam_sz, batch_sz, 0, 1, 2, 2,
|
||||||
torch.device("cpu"), GlobalScorerStub(), 0, 30,
|
GlobalScorerStub(), 0, 30,
|
||||||
False, ngram_repeat, set(),
|
False, ngram_repeat, set(),
|
||||||
torch.randint(0, 30, (batch_sz,)), False, 0.)
|
False, 0.)
|
||||||
|
beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
|
||||||
for i in range(ngram_repeat + 4):
|
for i in range(ngram_repeat + 4):
|
||||||
# non-interesting beams are going to get dummy values
|
# non-interesting beams are going to get dummy values
|
||||||
word_probs = torch.full(
|
word_probs = torch.full(
|
||||||
|
@ -78,8 +101,9 @@ class TestBeamSearch(unittest.TestCase):
|
||||||
# on initial round, only predicted scores for beam 0
|
# on initial round, only predicted scores for beam 0
|
||||||
# matter. Make two predictions. Top one will be repeated
|
# matter. Make two predictions. Top one will be repeated
|
||||||
# in beam zero, second one will live on in beam 1.
|
# in beam zero, second one will live on in beam 1.
|
||||||
word_probs[0::beam_sz, repeat_idx] = -0.1
|
word_probs[0::beam_sz, repeat_idx] = repeat_score
|
||||||
word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
|
word_probs[0::beam_sz, repeat_idx +
|
||||||
|
i + 1] = no_repeat_score
|
||||||
else:
|
else:
|
||||||
# predict the same thing in beam 0
|
# predict the same thing in beam 0
|
||||||
word_probs[0::beam_sz, repeat_idx] = 0
|
word_probs[0::beam_sz, repeat_idx] = 0
|
||||||
|
@ -87,22 +111,35 @@ class TestBeamSearch(unittest.TestCase):
|
||||||
word_probs[1::beam_sz, repeat_idx + i + 1] = 0
|
word_probs[1::beam_sz, repeat_idx + i + 1] = 0
|
||||||
attns = torch.randn(1, batch_sz * beam_sz, 53)
|
attns = torch.randn(1, batch_sz * beam_sz, 53)
|
||||||
beam.advance(word_probs, attns)
|
beam.advance(word_probs, attns)
|
||||||
if i <= ngram_repeat:
|
if i < ngram_repeat:
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
beam.topk_log_probs[0::beam_sz].eq(
|
beam.topk_log_probs[0::beam_sz].eq(
|
||||||
self.BLOCKED_SCORE).any())
|
self.BLOCKED_SCORE).any())
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
beam.topk_log_probs[1::beam_sz].eq(
|
beam.topk_log_probs[1::beam_sz].eq(
|
||||||
self.BLOCKED_SCORE).any())
|
self.BLOCKED_SCORE).any())
|
||||||
|
elif i == ngram_repeat:
|
||||||
|
# now beam 0 dies (along with the others), beam 1 -> beam 0
|
||||||
|
self.assertFalse(
|
||||||
|
beam.topk_log_probs[:, 0].eq(
|
||||||
|
self.BLOCKED_SCORE).any())
|
||||||
|
|
||||||
|
expected = torch.full([batch_sz, beam_sz], float("-inf"))
|
||||||
|
expected[:, 0] = no_repeat_score
|
||||||
|
expected[:, 1] = self.BLOCKED_SCORE
|
||||||
|
self.assertTrue(
|
||||||
|
beam.topk_log_probs[:, :].equal(expected))
|
||||||
else:
|
else:
|
||||||
# now beam 0 dies (along with the others), beam 1 -> beam 0
|
# now beam 0 dies (along with the others), beam 1 -> beam 0
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
beam.topk_log_probs[:, 0].eq(
|
beam.topk_log_probs[:, 0].eq(
|
||||||
self.BLOCKED_SCORE).any())
|
self.BLOCKED_SCORE).any())
|
||||||
|
|
||||||
|
expected = torch.full([batch_sz, beam_sz], float("-inf"))
|
||||||
|
expected[:, 0] = no_repeat_score
|
||||||
|
expected[:, 1:] = self.BLOCKED_SCORE
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
beam.topk_log_probs[:, 1:].equal(
|
beam.topk_log_probs.equal(expected))
|
||||||
torch.tensor(self.BLOCKED_SCORE)
|
|
||||||
.repeat(batch_sz, beam_sz-1)))
|
|
||||||
|
|
||||||
def test_repeating_excluded_index_does_not_die(self):
|
def test_repeating_excluded_index_does_not_die(self):
|
||||||
# beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
|
# beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
|
||||||
|
@ -111,12 +148,14 @@ class TestBeamSearch(unittest.TestCase):
|
||||||
repeat_idx = 47 # will be repeated and should be blocked
|
repeat_idx = 47 # will be repeated and should be blocked
|
||||||
repeat_idx_ignored = 7 # will be repeated and should not be blocked
|
repeat_idx_ignored = 7 # will be repeated and should not be blocked
|
||||||
ngram_repeat = 3
|
ngram_repeat = 3
|
||||||
|
device_init = torch.zeros(1, 1)
|
||||||
for batch_sz in [1, 3]:
|
for batch_sz in [1, 3]:
|
||||||
beam = BeamSearch(
|
beam = BeamSearch(
|
||||||
beam_sz, batch_sz, 0, 1, 2, 2,
|
beam_sz, batch_sz, 0, 1, 2, 2,
- torch.device("cpu"), GlobalScorerStub(), 0, 30,
+ GlobalScorerStub(), 0, 30,
False, ngram_repeat, {repeat_idx_ignored},
- torch.randint(0, 30, (batch_sz,)), False, 0.)
+ False, 0.)
+ beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full(
@@ -134,7 +173,7 @@ class TestBeamSearch(unittest.TestCase):
word_probs[2::beam_sz, repeat_idx_ignored] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns)
- if i <= ngram_repeat:
+ if i < ngram_repeat:
self.assertFalse(beam.topk_log_probs[:, 0].eq(
self.BLOCKED_SCORE).any())
self.assertFalse(beam.topk_log_probs[:, 1].eq(
@@ -153,10 +192,9 @@ class TestBeamSearch(unittest.TestCase):
self.assertFalse(beam.topk_log_probs[:, 1].eq(
self.BLOCKED_SCORE).all())
self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
- self.assertTrue(
- beam.topk_log_probs[:, 2:].equal(
- torch.tensor(self.BLOCKED_SCORE)
- .repeat(batch_sz, beam_sz - 2)))
+ self.assertTrue(beam.topk_log_probs[:, 2].eq(
+ self.BLOCKED_SCORE).all())

def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# beam 0 will always predict EOS. The other beams will predict
@@ -171,9 +209,11 @@ class TestBeamSearch(unittest.TestCase):
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
- torch.device("cpu"), GlobalScorerStub(),
+ GlobalScorerStub(),
min_length, 30, False, 0, set(),
- lengths, False, 0.)
+ False, 0.)
+ device_init = torch.zeros(1, 1)
+ beam.initialize(device_init, lengths)
all_attns = []
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
@@ -226,9 +266,11 @@ class TestBeamSearch(unittest.TestCase):
eos_idx = 2
beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2,
- torch.device("cpu"), GlobalScorerStub(),
+ GlobalScorerStub(),
min_length, 30, False, 0, set(),
- torch.randint(0, 30, (batch_sz,)), False, 0.)
+ False, 0.)
+ device_init = torch.zeros(1, 1)
+ beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full(
@@ -284,9 +326,12 @@ class TestBeamSearch(unittest.TestCase):
inp_lens = torch.randint(1, 30, (batch_sz,))
beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2,
- torch.device("cpu"), GlobalScorerStub(),
+ GlobalScorerStub(),
min_length, 30, True, 0, set(),
- inp_lens, False, 0.)
+ False, 0.)
+ device_init = torch.zeros(1, 1)
+ _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
+ # inp_lens is tiled in initialize, reassign to make attn match
for i in range(min_length + 2):
# non-interesting beams are going to get dummy values
word_probs = torch.full(
@@ -495,10 +540,11 @@ class TestBeamSearchAgainstReferenceCase(unittest.TestCase):
def test_beam_advance_against_known_reference(self):
beam = BeamSearch(
self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
- torch.device("cpu"), GlobalScorerStub(),
+ GlobalScorerStub(),
0, 30, False, 0, set(),
- torch.randint(0, 30, (self.BATCH_SZ,)), False, 0.)
+ False, 0.)
+ device_init = torch.zeros(1, 1)
+ beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ,)))
expected_beam_scores = self.init_step(beam, 1)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
@@ -513,9 +559,11 @@ class TestBeamWithLengthPenalty(TestBeamSearchAgainstReferenceCase):
scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
beam = BeamSearch(
self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
- torch.device("cpu"), scorer,
+ scorer,
0, 30, False, 0, set(),
- torch.randint(0, 30, (self.BATCH_SZ,)), False, 0.)
+ False, 0.)
+ device_init = torch.zeros(1, 1)
+ beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ,)))
expected_beam_scores = self.init_step(beam, 1.)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
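The updated tests above reflect an API change in this commit: `BeamSearch` no longer receives the device or the source lengths in its constructor; they are supplied later through a separate `initialize()` call. The following is a minimal sketch of that two-step call order as the tests exercise it; the toy sizes and the trivial scorer settings are illustrative only, not values prescribed by the commit.

```python
import torch
from onmt.translate import BeamSearch, GNMTGlobalScorer

beam_sz, batch_sz = 5, 3
# Positional args follow the new signature seen in the diff:
# pad=0, bos=1, eos=2, n_best=2, scorer, min/max length,
# return_attention, block_ngram_repeat, exclusion_tokens,
# stepwise_penalty, ratio.
beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                  GNMTGlobalScorer(0., 0., "none", "none"), 0, 30,
                  False, 0, set(),
                  False, 0.)
# Device placement and source lengths now arrive via initialize().
src_lengths = torch.randint(0, 30, (batch_sz,))
beam.initialize(torch.zeros(1, 1), src_lengths)
```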
@@ -1,114 +1,16 @@
import unittest
- from onmt.translate.random_sampling import RandomSampling
+ from onmt.translate.greedy_search import GreedySearch

import torch


- class TestRandomSampling(unittest.TestCase):
+ class TestGreedySearch(unittest.TestCase):
BATCH_SZ = 3
INP_SEQ_LEN = 53
DEAD_SCORE = -1e20

BLOCKED_SCORE = -10e20

- def test_advance_with_repeats_gets_blocked(self):
- n_words = 100
- repeat_idx = 47
- ngram_repeat = 3
- for batch_sz in [1, 3]:
- samp = RandomSampling(
- 0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat, set(),
- False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
- for i in range(ngram_repeat + 4):
- # predict repeat_idx over and over again
- word_probs = torch.full(
- (batch_sz, n_words), -float('inf'))
- word_probs[:, repeat_idx] = 0
- attns = torch.randn(1, batch_sz, 53)
- samp.advance(word_probs, attns)
- if i <= ngram_repeat:
- expected_scores = torch.zeros((batch_sz, 1))
- self.assertTrue(samp.topk_scores.equal(expected_scores))
- else:
- self.assertTrue(
- samp.topk_scores.equal(
- torch.tensor(self.BLOCKED_SCORE)
- .repeat(batch_sz, 1)))
-
- def test_advance_with_some_repeats_gets_blocked(self):
- # batch 0 and 7 will repeat, the rest will advance
- n_words = 100
- repeat_idx = 47
- other_repeat_idx = 12
- ngram_repeat = 3
- for batch_sz in [1, 3, 13]:
- samp = RandomSampling(
- 0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat, set(),
- False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
- for i in range(ngram_repeat + 4):
- word_probs = torch.full(
- (batch_sz, n_words), -float('inf'))
- # predict the same thing in batch 0 and 7 every i
- word_probs[0, repeat_idx] = 0
- if batch_sz > 7:
- word_probs[7, other_repeat_idx] = 0
- # push around what the other batches predict
- word_probs[1:7, repeat_idx + i] = 0
- if batch_sz > 7:
- word_probs[8:, repeat_idx + i] = 0
- attns = torch.randn(1, batch_sz, 53)
- samp.advance(word_probs, attns)
- if i <= ngram_repeat:
- self.assertFalse(
- samp.topk_scores.eq(
- self.BLOCKED_SCORE).any())
- else:
- # now batch 0 and 7 die
- self.assertTrue(samp.topk_scores[0].eq(self.BLOCKED_SCORE))
- if batch_sz > 7:
- self.assertTrue(samp.topk_scores[7].eq(
- self.BLOCKED_SCORE))
- self.assertFalse(
- samp.topk_scores[1:7].eq(
- self.BLOCKED_SCORE).any())
- if batch_sz > 7:
- self.assertFalse(
- samp.topk_scores[8:].eq(
- self.BLOCKED_SCORE).any())
-
- def test_repeating_excluded_index_does_not_die(self):
- # batch 0 will repeat excluded idx, batch 1 will repeat
- n_words = 100
- repeat_idx = 47 # will be repeated and should be blocked
- repeat_idx_ignored = 7 # will be repeated and should not be blocked
- ngram_repeat = 3
- for batch_sz in [1, 3, 17]:
- samp = RandomSampling(
- 0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat,
- {repeat_idx_ignored}, False, 30, 1., 5,
- torch.randint(0, 30, (batch_sz,)))
- for i in range(ngram_repeat + 4):
- word_probs = torch.full(
- (batch_sz, n_words), -float('inf'))
- word_probs[0, repeat_idx_ignored] = 0
- if batch_sz > 1:
- word_probs[1, repeat_idx] = 0
- word_probs[2:, repeat_idx + i] = 0
- attns = torch.randn(1, batch_sz, 53)
- samp.advance(word_probs, attns)
- if i <= ngram_repeat:
- self.assertFalse(samp.topk_scores.eq(
- self.BLOCKED_SCORE).any())
- else:
- # now batch 1 dies
- self.assertFalse(samp.topk_scores[0].eq(
- self.BLOCKED_SCORE).any())
- if batch_sz > 1:
- self.assertTrue(samp.topk_scores[1].eq(
- self.BLOCKED_SCORE).all())
- self.assertFalse(samp.topk_scores[2:].eq(
- self.BLOCKED_SCORE).any())
-
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# batch 0 will always predict EOS. The other batches will predict
# non-eos scores.
@@ -120,9 +22,10 @@ class TestRandomSampling(unittest.TestCase):
min_length = 5
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
- samp = RandomSampling(
- 0, 1, 2, batch_sz, torch.device("cpu"), min_length,
- False, set(), False, 30, 1., 1, lengths)
+ samp = GreedySearch(
+ 0, 1, 2, batch_sz, min_length,
+ False, set(), False, 30, 1., 1)
+ samp.initialize(torch.zeros(1), lengths)
all_attns = []
for i in range(min_length + 4):
word_probs = torch.full(
@@ -160,10 +63,10 @@ class TestRandomSampling(unittest.TestCase):
[6., 1.]), dim=0)
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
- samp = RandomSampling(
- 0, 1, 2, batch_sz, torch.device("cpu"), 0,
- False, set(), False, 30, temp, 1, lengths)
+ samp = GreedySearch(
+ 0, 1, 2, batch_sz, 0,
+ False, set(), False, 30, temp, 1)
+ samp.initialize(torch.zeros(1), lengths)
# initial step
i = 0
word_probs = torch.full(
@@ -232,10 +135,10 @@ class TestRandomSampling(unittest.TestCase):
[6., 1.]), dim=0)
eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,))
- samp = RandomSampling(
- 0, 1, 2, batch_sz, torch.device("cpu"), 0,
- False, set(), False, 30, temp, 2, lengths)
+ samp = GreedySearch(
+ 0, 1, 2, batch_sz, 0,
+ False, set(), False, 30, temp, 2)
+ samp.initialize(torch.zeros(1), lengths)
# initial step
i = 0
for _ in range(100):
@@ -12,7 +12,7 @@ import codecs
import onmt
import onmt.inputters
import onmt.opts
- import preprocess
+ import onmt.bin.preprocess as preprocess


parser = configargparse.ArgumentParser(description='preprocess.py')
@@ -49,11 +49,12 @@ class TestData(unittest.TestCase):

src_reader = onmt.inputters.str2reader[opt.data_type].from_opt(opt)
tgt_reader = onmt.inputters.str2reader["text"].from_opt(opt)
+ align_reader = onmt.inputters.str2reader["text"].from_opt(opt)
preprocess.build_save_dataset(
- 'train', fields, src_reader, tgt_reader, opt)
+ 'train', fields, src_reader, tgt_reader, align_reader, opt)

preprocess.build_save_dataset(
- 'valid', fields, src_reader, tgt_reader, opt)
+ 'valid', fields, src_reader, tgt_reader, align_reader, opt)

# Remove the generated *pt files.
for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
@@ -112,27 +112,24 @@ class TestServerModel(unittest.TestCase):
sm = ServerModel(opt, model_id, model_root=model_root, load=True)
inp = [{"src": "hello how are you today"},
{"src": "good morning to you ."}]
- results, scores, n_best, time = sm.run(inp)
+ results, scores, n_best, time, aligns = sm.run(inp)
self.assertIsInstance(results, list)
for sentence_string in results:
self.assertIsInstance(sentence_string, string_types)
self.assertIsInstance(scores, list)
for elem in scores:
self.assertIsInstance(elem, float)
+ self.assertIsInstance(aligns, list)
+ for align_list in aligns:
+ for align_string in align_list:
+ if align_string is not None:
+ self.assertIsInstance(align_string, string_types)
self.assertEqual(len(results), len(scores))
- self.assertEqual(len(scores), len(inp))
+ self.assertEqual(len(scores), len(inp) * n_best)
- self.assertEqual(n_best, 1)
self.assertEqual(len(time), 1)
self.assertIsInstance(time, dict)
self.assertIn("translation", time)

- def test_nbest_init_fails(self):
- model_id = 0
- opt = {"models": ["test_model.pt"], "n_best": 2}
- model_root = TEST_DIR
- with self.assertRaises(ValueError):
- ServerModel(opt, model_id, model_root=model_root, load=True)
-
class TestTranslationServer(unittest.TestCase):
# this could be considered an integration test because it touches
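As the updated test shows, `ServerModel.run` now returns a fifth value holding word alignments, and the scores list carries one entry per n-best hypothesis per input rather than one per input. Below is a hedged usage sketch of that new contract; the model path, options and `model_root` are placeholders for illustration, not values from this commit.

```python
# Sketch only: assumes a trained model file exists at the placeholder path.
from onmt.translate.translation_server import ServerModel

opt = {"models": ["test_model.pt"]}              # placeholder model file
sm = ServerModel(opt, 0, model_root=".", load=True)
inp = [{"src": "hello how are you today"}]

results, scores, n_best, times, aligns = sm.run(inp)
# scores are now per hypothesis, not per input sentence
assert len(scores) == len(inp) * n_best
# aligns is a list of per-hypothesis alignment strings (possibly None)
```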
@@ -4,7 +4,7 @@ import os

import torch

- from onmt.inputters.inputter import build_dataset_iter, \
+ from onmt.inputters.inputter import build_dataset_iter, patch_fields, \
load_old_vocab, old_style_vocab, build_dataset_iter_multiple
from onmt.model_builder import build_model
from onmt.utils.optimizers import Optimizer
@@ -69,6 +69,9 @@ def main(opt, device_id, batch_queue=None, semaphore=None):
else:
fields = vocab

+ # patch for fields that may be missing in old data/model
+ patch_fields(opt, fields)
+
# Report src and tgt vocab sizes, including for features
for side in ['src', 'tgt']:
f = fields[side]
@@ -142,5 +145,5 @@ def main(opt, device_id, batch_queue=None, semaphore=None):
valid_iter=valid_iter,
valid_steps=opt.valid_steps)

- if opt.tensorboard:
+ if trainer.report_manager.tensorboard_writer is not None:
trainer.report_manager.tensorboard_writer.close()
@@ -9,7 +9,6 @@
users of this library) for the strategy things we do.
"""

- from copy import deepcopy
import torch
import traceback

@@ -58,19 +57,39 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) \
if opt.early_stopping > 0 else None

- report_manager = onmt.utils.build_report_manager(opt)
+ source_noise = None
+ if len(opt.src_noise) > 0:
+ src_field = dict(fields)["src"].base_field
+ corpus_id_field = dict(fields).get("corpus_id", None)
+ if corpus_id_field is not None:
+ ids_to_noise = corpus_id_field.numericalize(opt.data_to_noise)
+ else:
+ ids_to_noise = None
+ source_noise = onmt.modules.source_noise.MultiNoise(
+ opt.src_noise,
+ opt.src_noise_prob,
+ ids_to_noise=ids_to_noise,
+ pad_idx=src_field.pad_token,
+ end_of_sentence_mask=src_field.end_of_sentence_mask,
+ word_start_mask=src_field.word_start_mask,
+ device_id=device_id
+ )
+
+ report_manager = onmt.utils.build_report_manager(opt, gpu_rank)
trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
shard_size, norm_method,
accum_count, accum_steps,
n_gpu, gpu_rank,
gpu_verbose_level, report_manager,
+ with_align=True if opt.lambda_align > 0 else False,
model_saver=model_saver if gpu_rank == 0 else None,
average_decay=average_decay,
average_every=average_every,
model_dtype=opt.model_dtype,
earlystopper=earlystopper,
dropout=dropout,
- dropout_steps=dropout_steps)
+ dropout_steps=dropout_steps,
+ source_noise=source_noise)
return trainer


@@ -104,10 +123,11 @@ class Trainer(object):
trunc_size=0, shard_size=32,
norm_method="sents", accum_count=[1],
accum_steps=[0],
- n_gpu=1, gpu_rank=1,
- gpu_verbose_level=0, report_manager=None, model_saver=None,
+ n_gpu=1, gpu_rank=1, gpu_verbose_level=0,
+ report_manager=None, with_align=False, model_saver=None,
average_decay=0, average_every=1, model_dtype='fp32',
- earlystopper=None, dropout=[0.3], dropout_steps=[0]):
+ earlystopper=None, dropout=[0.3], dropout_steps=[0],
+ source_noise=None):
# Basic attributes.
self.model = model
self.train_loss = train_loss
@@ -123,6 +143,7 @@ class Trainer(object):
self.gpu_rank = gpu_rank
self.gpu_verbose_level = gpu_verbose_level
self.report_manager = report_manager
+ self.with_align = with_align
self.model_saver = model_saver
self.average_decay = average_decay
self.moving_average = None
@@ -131,6 +152,7 @@ class Trainer(object):
self.earlystopper = earlystopper
self.dropout = dropout
self.dropout_steps = dropout_steps
+ self.source_noise = source_noise

for i in range(len(self.accum_count_l)):
assert self.accum_count_l[i] > 0
@@ -290,13 +312,16 @@ class Trainer(object):
Returns:
:obj:`nmt.Statistics`: validation loss statistics
"""
+ valid_model = self.model
if moving_average:
- valid_model = deepcopy(self.model)
+ # swap model params w/ moving average
+ # (and keep the original parameters)
+ model_params_data = []
for avg, param in zip(self.moving_average,
valid_model.parameters()):
- param.data = avg.data
- else:
- valid_model = self.model
+ model_params_data.append(param.data)
+ param.data = avg.data.half() if self.optim._fp16 == "legacy" \
+ else avg.data

# Set model in validating mode.
valid_model.eval()
@@ -310,17 +335,19 @@ class Trainer(object):
tgt = batch.tgt

# F-prop through the model.
- outputs, attns = valid_model(src, tgt, src_lengths)
+ outputs, attns = valid_model(src, tgt, src_lengths,
+ with_align=self.with_align)

# Compute loss.
_, batch_stats = self.valid_loss(batch, outputs, attns)

# Update statistics.
stats.update(batch_stats)

if moving_average:
- del valid_model
- else:
+ for param_data, param in zip(model_params_data,
+ self.model.parameters()):
+ param.data = param_data

# Set model back to training mode.
valid_model.train()

@@ -339,6 +366,8 @@ class Trainer(object):
else:
trunc_size = target_size

+ batch = self.maybe_noise_source(batch)
+
src, src_lengths = batch.src if isinstance(batch.src, tuple) \
else (batch.src, None)
if src_lengths is not None:
@@ -354,7 +383,9 @@ class Trainer(object):
# 2. F-prop all but generator.
if self.accum_count == 1:
self.optim.zero_grad()
- outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt)
+ outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
+ with_align=self.with_align)
bptt = True

# 3. Compute loss.
@@ -454,3 +485,8 @@ class Trainer(object):
return self.report_manager.report_step(
learning_rate, step, train_stats=train_stats,
valid_stats=valid_stats)
+
+ def maybe_noise_source(self, batch):
+ if self.source_noise is not None:
+ return self.source_noise(batch)
+ return batch
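The validation change above replaces a full `deepcopy` of the model with an in-place parameter swap when a moving average is tracked: the live tensors are stashed, the parameters are pointed at the averaged tensors for the validation pass, and the originals are restored afterwards. Here is a standalone sketch of that swap/restore idea on a toy module (the model and the stand-in average are illustrative, not the Trainer itself).

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
moving_average = [p.data.clone() for p in model.parameters()]  # stand-in EMA

saved = []
for avg, param in zip(moving_average, model.parameters()):
    saved.append(param.data)        # keep the live tensor around
    param.data = avg                # validate with the averaged weights
model.eval()
# ... run validation forward passes here ...
for original, param in zip(saved, model.parameters()):
    param.data = original           # put the live weights back
model.train()
```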
@@ -1,15 +1,14 @@
""" Modules for translation """
from onmt.translate.translator import Translator
from onmt.translate.translation import Translation, TranslationBuilder
- from onmt.translate.beam import Beam, GNMTGlobalScorer
- from onmt.translate.beam_search import BeamSearch
+ from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer
from onmt.translate.decode_strategy import DecodeStrategy
- from onmt.translate.random_sampling import RandomSampling
+ from onmt.translate.greedy_search import GreedySearch
from onmt.translate.penalties import PenaltyBuilder
from onmt.translate.translation_server import TranslationServer, \
ServerModelError

- __all__ = ['Translator', 'Translation', 'Beam', 'BeamSearch',
+ __all__ = ['Translator', 'Translation', 'BeamSearch',
'GNMTGlobalScorer', 'TranslationBuilder',
'PenaltyBuilder', 'TranslationServer', 'ServerModelError',
- "DecodeStrategy", "RandomSampling"]
+ "DecodeStrategy", "GreedySearch"]
@@ -1,293 +0,0 @@
- from __future__ import division
- import torch
- from onmt.translate import penalties
-
- import warnings
-
-
- class Beam(object):
- """Class for managing the internals of the beam search process.
-
- Takes care of beams, back pointers, and scores.
-
- Args:
- size (int): Number of beams to use.
- pad (int): Magic integer in output vocab.
- bos (int): Magic integer in output vocab.
- eos (int): Magic integer in output vocab.
- n_best (int): Don't stop until at least this many beams have
- reached EOS.
- cuda (bool): use gpu
- global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
- min_length (int): Shortest acceptable generation, not counting
- begin-of-sentence or end-of-sentence.
- stepwise_penalty (bool): Apply coverage penalty at every step.
- block_ngram_repeat (int): Block beams where
- ``block_ngram_repeat``-grams repeat.
- exclusion_tokens (set[int]): If a gram contains any of these
- token indices, it may repeat.
- """
-
- def __init__(self, size, pad, bos, eos,
- n_best=1, cuda=False,
- global_scorer=None,
- min_length=0,
- stepwise_penalty=False,
- block_ngram_repeat=0,
- exclusion_tokens=set()):
-
- self.size = size
- self.tt = torch.cuda if cuda else torch
-
- # The score for each translation on the beam.
- self.scores = self.tt.FloatTensor(size).zero_()
- self.all_scores = []
-
- # The backpointers at each time-step.
- self.prev_ks = []
-
- # The outputs at each time-step.
- self.next_ys = [self.tt.LongTensor(size)
- .fill_(pad)]
- self.next_ys[0][0] = bos
-
- # Has EOS topped the beam yet.
- self._eos = eos
- self.eos_top = False
-
- # The attentions (matrix) for each time.
- self.attn = []
-
- # Time and k pair for finished.
- self.finished = []
- self.n_best = n_best
-
- # Information for global scoring.
- self.global_scorer = global_scorer
- self.global_state = {}
-
- # Minimum prediction length
- self.min_length = min_length
-
- # Apply Penalty at every step
- self.stepwise_penalty = stepwise_penalty
- self.block_ngram_repeat = block_ngram_repeat
- self.exclusion_tokens = exclusion_tokens
-
- @property
- def current_predictions(self):
- return self.next_ys[-1]
-
- @property
- def current_origin(self):
- """Get the backpointers for the current timestep."""
- return self.prev_ks[-1]
-
- def advance(self, word_probs, attn_out):
- """
- Given prob over words for every last beam `wordLk` and attention
- `attn_out`: Compute and update the beam search.
-
- Args:
- word_probs (FloatTensor): probs of advancing from the last step
- ``(K, words)``
- attn_out (FloatTensor): attention at the last step
-
- Returns:
- bool: True if beam search is complete.
- """
-
- num_words = word_probs.size(1)
- if self.stepwise_penalty:
- self.global_scorer.update_score(self, attn_out)
- # force the output to be longer than self.min_length
- cur_len = len(self.next_ys)
- if cur_len <= self.min_length:
- # assumes there are len(word_probs) predictions OTHER
- # than EOS that are greater than -1e20
- for k in range(len(word_probs)):
- word_probs[k][self._eos] = -1e20
-
- # Sum the previous scores.
- if len(self.prev_ks) > 0:
- beam_scores = word_probs + self.scores.unsqueeze(1)
- # Don't let EOS have children.
- for i in range(self.next_ys[-1].size(0)):
- if self.next_ys[-1][i] == self._eos:
- beam_scores[i] = -1e20
-
- # Block ngram repeats
- if self.block_ngram_repeat > 0:
- le = len(self.next_ys)
- for j in range(self.next_ys[-1].size(0)):
- hyp, _ = self.get_hyp(le - 1, j)
- ngrams = set()
- fail = False
- gram = []
- for i in range(le - 1):
- # Last n tokens, n = block_ngram_repeat
- gram = (gram +
- [hyp[i].item()])[-self.block_ngram_repeat:]
- # Skip the blocking if it is in the exclusion list
- if set(gram) & self.exclusion_tokens:
- continue
- if tuple(gram) in ngrams:
- fail = True
- ngrams.add(tuple(gram))
- if fail:
- beam_scores[j] = -10e20
- else:
- beam_scores = word_probs[0]
- flat_beam_scores = beam_scores.view(-1)
- best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0,
- True, True)
-
- self.all_scores.append(self.scores)
- self.scores = best_scores
-
- # best_scores_id is flattened beam x word array, so calculate which
- # word and beam each score came from
- prev_k = best_scores_id / num_words
- self.prev_ks.append(prev_k)
- self.next_ys.append((best_scores_id - prev_k * num_words))
- self.attn.append(attn_out.index_select(0, prev_k))
- self.global_scorer.update_global_state(self)
-
- for i in range(self.next_ys[-1].size(0)):
- if self.next_ys[-1][i] == self._eos:
- global_scores = self.global_scorer.score(self, self.scores)
- s = global_scores[i]
- self.finished.append((s, len(self.next_ys) - 1, i))
-
- # End condition is when top-of-beam is EOS and no global score.
- if self.next_ys[-1][0] == self._eos:
- self.all_scores.append(self.scores)
- self.eos_top = True
-
- @property
- def done(self):
- return self.eos_top and len(self.finished) >= self.n_best
-
- def sort_finished(self, minimum=None):
- if minimum is not None:
- i = 0
- # Add from beam until we have minimum outputs.
- while len(self.finished) < minimum:
- global_scores = self.global_scorer.score(self, self.scores)
- s = global_scores[i]
- self.finished.append((s, len(self.next_ys) - 1, i))
- i += 1
-
- self.finished.sort(key=lambda a: -a[0])
- scores = [sc for sc, _, _ in self.finished]
- ks = [(t, k) for _, t, k in self.finished]
- return scores, ks
-
- def get_hyp(self, timestep, k):
- """Walk back to construct the full hypothesis."""
- hyp, attn = [], []
- for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
- hyp.append(self.next_ys[j + 1][k])
- attn.append(self.attn[j][k])
- k = self.prev_ks[j][k]
- return hyp[::-1], torch.stack(attn[::-1])
-
-
- class GNMTGlobalScorer(object):
- """NMT re-ranking.
-
- Args:
- alpha (float): Length parameter.
- beta (float): Coverage parameter.
- length_penalty (str): Length penalty strategy.
- coverage_penalty (str): Coverage penalty strategy.
-
- Attributes:
- alpha (float): See above.
- beta (float): See above.
- length_penalty (callable): See :class:`penalties.PenaltyBuilder`.
- coverage_penalty (callable): See :class:`penalties.PenaltyBuilder`.
- has_cov_pen (bool): See :class:`penalties.PenaltyBuilder`.
- has_len_pen (bool): See :class:`penalties.PenaltyBuilder`.
- """
-
- @classmethod
- def from_opt(cls, opt):
- return cls(
- opt.alpha,
- opt.beta,
- opt.length_penalty,
- opt.coverage_penalty)
-
- def __init__(self, alpha, beta, length_penalty, coverage_penalty):
- self._validate(alpha, beta, length_penalty, coverage_penalty)
- self.alpha = alpha
- self.beta = beta
- penalty_builder = penalties.PenaltyBuilder(coverage_penalty,
- length_penalty)
- self.has_cov_pen = penalty_builder.has_cov_pen
- # Term will be subtracted from probability
- self.cov_penalty = penalty_builder.coverage_penalty
-
- self.has_len_pen = penalty_builder.has_len_pen
- # Probability will be divided by this
- self.length_penalty = penalty_builder.length_penalty
-
- @classmethod
- def _validate(cls, alpha, beta, length_penalty, coverage_penalty):
- # these warnings indicate that either the alpha/beta
- # forces a penalty to be a no-op, or a penalty is a no-op but
- # the alpha/beta would suggest otherwise.
- if length_penalty is None or length_penalty == "none":
- if alpha != 0:
- warnings.warn("Non-default `alpha` with no length penalty. "
- "`alpha` has no effect.")
- else:
- # using some length penalty
- if length_penalty == "wu" and alpha == 0.:
- warnings.warn("Using length penalty Wu with alpha==0 "
- "is equivalent to using length penalty none.")
- if coverage_penalty is None or coverage_penalty == "none":
- if beta != 0:
- warnings.warn("Non-default `beta` with no coverage penalty. "
- "`beta` has no effect.")
- else:
- # using some coverage penalty
- if beta == 0.:
- warnings.warn("Non-default coverage penalty with beta==0 "
- "is equivalent to using coverage penalty none.")
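The deleted `Beam` class above keeps per-step outputs (`next_ys`) and backpointers (`prev_ks`), and `get_hyp` rebuilds a finished hypothesis by walking those backpointers from the chosen end slot back to the start. The following is a toy illustration of that walk on plain Python lists; the data values are made up.

```python
# next_ys[t][k] is the token chosen at step t in beam slot k;
# prev_ks[t-1][k] says which slot at step t-1 that choice extended.
def walk_back(next_ys, prev_ks, timestep, k):
    hyp = []
    for j in range(len(prev_ks[:timestep]) - 1, -1, -1):
        hyp.append(next_ys[j + 1][k])
        k = prev_ks[j][k]
    return hyp[::-1]

next_ys = [[1, 1], [4, 7], [9, 5]]   # tokens per step, 2 beam slots
prev_ks = [[0, 0], [1, 0]]           # backpointers for steps 1 and 2
print(walk_back(next_ys, prev_ks, timestep=2, k=0))   # [7, 9]
```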
@@ -1,6 +1,9 @@
import torch
+ from onmt.translate import penalties
from onmt.translate.decode_strategy import DecodeStrategy
+ from onmt.utils.misc import tile

+ import warnings


class BeamSearch(DecodeStrategy):
@@ -19,15 +22,12 @@ class BeamSearch(DecodeStrategy):
eos (int): See base.
n_best (int): Don't stop until at least this many beams have
reached EOS.
- mb_device (torch.device or str): See base ``device``.
global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
min_length (int): See base.
max_length (int): See base.
return_attention (bool): See base.
block_ngram_repeat (int): See base.
exclusion_tokens (set[int]): See base.
- memory_lengths (LongTensor): Lengths of encodings. Used for
- masking attentions.

Attributes:
top_beam_finished (ByteTensor): Shape ``(B,)``.
@@ -36,6 +36,8 @@ class BeamSearch(DecodeStrategy):
alive_seq (LongTensor): See base.
topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These
are the scores used for the topk operation.
+ memory_lengths (LongTensor): Lengths of encodings. Used for
+ masking attentions.
select_indices (LongTensor or NoneType): Shape
``(B x beam_size,)``. This is just a flat view of the
``_batch_index``.
@@ -53,19 +55,18 @@ class BeamSearch(DecodeStrategy):
of score (float), sequence (long), and attention (float or None).
"""

- def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device,
+ def __init__(self, beam_size, batch_size, pad, bos, eos, n_best,
global_scorer, min_length, max_length, return_attention,
- block_ngram_repeat, exclusion_tokens, memory_lengths,
+ block_ngram_repeat, exclusion_tokens,
stepwise_penalty, ratio):
super(BeamSearch, self).__init__(
- pad, bos, eos, batch_size, mb_device, beam_size, min_length,
+ pad, bos, eos, batch_size, beam_size, min_length,
block_ngram_repeat, exclusion_tokens, return_attention,
max_length)
# beam parameters
self.global_scorer = global_scorer
self.beam_size = beam_size
self.n_best = n_best
- self.batch_size = batch_size
self.ratio = ratio

# result caching
@@ -73,26 +74,14 @@ class BeamSearch(DecodeStrategy):

# beam state
self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
- self.best_scores = torch.full([batch_size], -1e10, dtype=torch.float,
- device=mb_device)
+ # BoolTensor was introduced in pytorch 1.2
+ try:
+ self.top_beam_finished = self.top_beam_finished.bool()
+ except AttributeError:
+ pass
self._batch_offset = torch.arange(batch_size, dtype=torch.long)
- self._beam_offset = torch.arange(
- 0, batch_size * beam_size, step=beam_size, dtype=torch.long,
- device=mb_device)
- self.topk_log_probs = torch.tensor(
- [0.0] + [float("-inf")] * (beam_size - 1), device=mb_device
- ).repeat(batch_size)
- self.select_indices = None
- self._memory_lengths = memory_lengths
-
- # buffers for the topk scores and 'backpointer'
- self.topk_scores = torch.empty((batch_size, beam_size),
- dtype=torch.float, device=mb_device)
- self.topk_ids = torch.empty((batch_size, beam_size), dtype=torch.long,
- device=mb_device)
- self._batch_index = torch.empty([batch_size, beam_size],
- dtype=torch.long, device=mb_device)
+ self.select_indices = None
self.done = False
# "global state" of the old beam
self._prev_penalty = None
@@ -104,20 +93,60 @@ class BeamSearch(DecodeStrategy):
not stepwise_penalty and self.global_scorer.has_cov_pen)
self._cov_pen = self.global_scorer.has_cov_pen

+ def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
+ """Initialize for decoding.
+ Repeat src objects `beam_size` times.
+ """
+
+ def fn_map_state(state, dim):
+ return tile(state, self.beam_size, dim=dim)
+
+ if isinstance(memory_bank, tuple):
+ memory_bank = tuple(tile(x, self.beam_size, dim=1)
+ for x in memory_bank)
+ mb_device = memory_bank[0].device
+ else:
+ memory_bank = tile(memory_bank, self.beam_size, dim=1)
+ mb_device = memory_bank.device
+ if src_map is not None:
+ src_map = tile(src_map, self.beam_size, dim=1)
+ if device is None:
+ device = mb_device
+
+ self.memory_lengths = tile(src_lengths, self.beam_size)
+ super(BeamSearch, self).initialize(
+ memory_bank, self.memory_lengths, src_map, device)
+ self.best_scores = torch.full(
+ [self.batch_size], -1e10, dtype=torch.float, device=device)
+ self._beam_offset = torch.arange(
+ 0, self.batch_size * self.beam_size, step=self.beam_size,
+ dtype=torch.long, device=device)
+ self.topk_log_probs = torch.tensor(
+ [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
+ ).repeat(self.batch_size)
+ # buffers for the topk scores and 'backpointer'
+ self.topk_scores = torch.empty((self.batch_size, self.beam_size),
+ dtype=torch.float, device=device)
+ self.topk_ids = torch.empty((self.batch_size, self.beam_size),
+ dtype=torch.long, device=device)
+ self._batch_index = torch.empty([self.batch_size, self.beam_size],
+ dtype=torch.long, device=device)
+ return fn_map_state, memory_bank, self.memory_lengths, src_map
+
@property
def current_predictions(self):
return self.alive_seq[:, -1]

- @property
- def current_origin(self):
- return self.select_indices
-
@property
def current_backptr(self):
# for testing
return self.select_indices.view(self.batch_size, self.beam_size)\
.fmod(self.beam_size)

+ @property
+ def batch_offset(self):
+ return self._batch_offset
+
def advance(self, log_probs, attn, partial=None, partialf=None):
vocab_size = log_probs.size(-1)

@@ -137,16 +166,19 @@ class BeamSearch(DecodeStrategy):
# Multiply probs by the beam probability.
log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

- self.block_ngram_repeats(log_probs)

# if the sequence ends now, then the penalty is the current
# length + 1, to include the EOS token
length_penalty = self.global_scorer.length_penalty(
step + 1, alpha=self.global_scorer.alpha)

- # Flatten probs into a list of possibilities.
curr_scores = log_probs / length_penalty

+ # Avoid any direction that would repeat unwanted ngrams
+ self.block_ngram_repeats(curr_scores)
+
+ # Flatten probs into a list of possibilities.
curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)

if partialf and step and step == len(partial) + 1:
# proxyf = []
# for i in range(self.beam_size):
@@ -169,7 +201,6 @@ class BeamSearch(DecodeStrategy):

torch.topk(curr_scores, self.beam_size, dim=-1,
out=(self.topk_scores, self.topk_ids))

# if partialf and step and step == len(partial) + 1:
# print(self.topk_scores)
# Recover log probs.
@@ -181,18 +212,18 @@ class BeamSearch(DecodeStrategy):
torch.div(self.topk_ids, vocab_size, out=self._batch_index)
self._batch_index += self._beam_offset[:_B].unsqueeze(1)
self.select_indices = self._batch_index.view(_B * self.beam_size)

self.topk_ids.fmod_(vocab_size)  # resolve true word ids

# Append last prediction.
if partial and step and step <= len(partial):
self.topk_ids = torch.full([self.batch_size * self.beam_size, 1], partial[step-1], dtype=torch.long, device=torch.device('cpu'))
self.select_indices = torch.tensor([0, 0, 0, 0, 0])

self.alive_seq = torch.cat(
[self.alive_seq.index_select(0, self.select_indices),
self.topk_ids.view(_B * self.beam_size, 1)], -1)

+ self.maybe_update_forbidden_tokens()
+
if self.return_attention or self._cov_pen:
current_attn = attn.index_select(1, self.select_indices)
if step == 1:
@@ -219,10 +250,9 @@ class BeamSearch(DecodeStrategy):
cov_penalty = self.global_scorer.cov_penalty(
self._coverage,
beta=self.global_scorer.beta)
- self.topk_scores -= cov_penalty.view(_B, self.beam_size)
+ self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()

self.is_finished = self.topk_ids.eq(self.eos)
- # print(self.is_finished, self.topk_ids)
self.ensure_max_length()

def update_finished(self):
@@ -240,11 +270,11 @@ class BeamSearch(DecodeStrategy):
step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
if self.alive_attn is not None else None)
non_finished_batch = []
- for i in range(self.is_finished.size(0)):
+ for i in range(self.is_finished.size(0)):  # Batch level
b = self._batch_offset[i]
finished_hyp = self.is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
- for j in finished_hyp:
+ for j in finished_hyp:  # Beam level: finished beam j in batch i
if self.ratio > 0:
s = self.topk_scores[i, j] / (step + 1)
if self.best_scores[b] < s:
@@ -252,12 +282,12 @@ class BeamSearch(DecodeStrategy):
self.hypotheses[b].append((
self.topk_scores[i, j],
predictions[i, j, 1:],  # Ignore start_token.
- attention[:, i, j, :self._memory_lengths[i]]
+ attention[:, i, j, :self.memory_lengths[i]]
if attention is not None else None))
# End condition is the top beam finished and we can return
# n_best hypotheses.
if self.ratio > 0:
- pred_len = self._memory_lengths[i] * self.ratio
+ pred_len = self.memory_lengths[i] * self.ratio
finish_flag = ((self.topk_scores[i, 0] / pred_len)
<= self.best_scores[b]) or \
self.is_finished[i].all()
@@ -270,7 +300,7 @@ class BeamSearch(DecodeStrategy):
if n >= self.n_best:
break
self.scores[b].append(score)
- self.predictions[b].append(pred)
+ self.predictions[b].append(pred)  # ``(batch, n_best,)``
self.attention[b].append(
attn if attn is not None else [])
else:
@@ -307,3 +337,68 @@ class BeamSearch(DecodeStrategy):
if self._stepwise_cov_pen:
self._prev_penalty = self._prev_penalty.index_select(
0, non_finished)
+
+
+ class GNMTGlobalScorer(object):
+ """NMT re-ranking.
+
+ Args:
+ alpha (float): Length parameter.
+ beta (float): Coverage parameter.
+ length_penalty (str): Length penalty strategy.
+ coverage_penalty (str): Coverage penalty strategy.
+
+ Attributes:
+ alpha (float): See above.
+ beta (float): See above.
+ length_penalty (callable): See :class:`penalties.PenaltyBuilder`.
+ coverage_penalty (callable): See :class:`penalties.PenaltyBuilder`.
+ has_cov_pen (bool): See :class:`penalties.PenaltyBuilder`.
+ has_len_pen (bool): See :class:`penalties.PenaltyBuilder`.
+ """
+
+ @classmethod
+ def from_opt(cls, opt):
+ return cls(
+ opt.alpha,
+ opt.beta,
+ opt.length_penalty,
+ opt.coverage_penalty)
+
+ def __init__(self, alpha, beta, length_penalty, coverage_penalty):
+ self._validate(alpha, beta, length_penalty, coverage_penalty)
+ self.alpha = alpha
+ self.beta = beta
+ penalty_builder = penalties.PenaltyBuilder(coverage_penalty,
+ length_penalty)
+ self.has_cov_pen = penalty_builder.has_cov_pen
+ # Term will be subtracted from probability
+ self.cov_penalty = penalty_builder.coverage_penalty
+
+ self.has_len_pen = penalty_builder.has_len_pen
+ # Probability will be divided by this
+ self.length_penalty = penalty_builder.length_penalty
+
+ @classmethod
+ def _validate(cls, alpha, beta, length_penalty, coverage_penalty):
+ # these warnings indicate that either the alpha/beta
+ # forces a penalty to be a no-op, or a penalty is a no-op but
+ # the alpha/beta would suggest otherwise.
+ if length_penalty is None or length_penalty == "none":
+ if alpha != 0:
+ warnings.warn("Non-default `alpha` with no length penalty. "
+ "`alpha` has no effect.")
+ else:
+ # using some length penalty
+ if length_penalty == "wu" and alpha == 0.:
+ warnings.warn("Using length penalty Wu with alpha==0 "
+ "is equivalent to using length penalty none.")
+ if coverage_penalty is None or coverage_penalty == "none":
+ if beta != 0:
+ warnings.warn("Non-default `beta` with no coverage penalty. "
+ "`beta` has no effect.")
+ else:
+ # using some coverage penalty
+ if beta == 0.:
+ warnings.warn("Non-default coverage penalty with beta==0 "
+ "is equivalent to using coverage penalty none.")
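The new `BeamSearch.initialize()` above moves all device-dependent state out of the constructor and repeats every per-batch tensor `beam_size` times along the batch dimension with `tile` from `onmt.utils.misc`. The sketch below shows what that repetition amounts to for a simple memory bank and length vector, using `repeat_interleave` as a stand-in that I assume behaves equivalently to `tile` for this purpose.

```python
import torch

beam_size = 3
memory_bank = torch.randn(7, 2, 16)        # (src_len, batch, hidden)
src_lengths = torch.tensor([7, 5])

# Repeat each batch entry beam_size times, interleaved, along the batch dim.
tiled_bank = memory_bank.repeat_interleave(beam_size, dim=1)
tiled_lengths = src_lengths.repeat_interleave(beam_size, dim=0)

print(tiled_bank.shape)      # torch.Size([7, 6, 16]): 2 sentences x 3 beams
print(tiled_lengths)         # tensor([7, 7, 7, 5, 5, 5])
```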
@@ -9,7 +9,6 @@ class DecodeStrategy(object):
         bos (int): Magic integer in output vocab.
         eos (int): Magic integer in output vocab.
         batch_size (int): Current batch size.
-        device (torch.device or str): Device for memory bank (encoder).
         parallel_paths (int): Decoding strategies like beam search
             use parallel paths. Each batch is repeated ``parallel_paths``
             times in relevant state tensors.
@@ -54,7 +53,7 @@ class DecodeStrategy(object):
         done (bool): See above.
     """

-    def __init__(self, pad, bos, eos, batch_size, device, parallel_paths,
+    def __init__(self, pad, bos, eos, batch_size, parallel_paths,
                  min_length, block_ngram_repeat, exclusion_tokens,
                  return_attention, max_length):
@@ -63,27 +62,43 @@ class DecodeStrategy(object):
         self.bos = bos
         self.eos = eos

+        self.batch_size = batch_size
+        self.parallel_paths = parallel_paths
         # result caching
         self.predictions = [[] for _ in range(batch_size)]
         self.scores = [[] for _ in range(batch_size)]
         self.attention = [[] for _ in range(batch_size)]

-        self.alive_seq = torch.full(
-            [batch_size * parallel_paths, 1], self.bos,
-            dtype=torch.long, device=device)
-        self.is_finished = torch.zeros(
-            [batch_size, parallel_paths],
-            dtype=torch.uint8, device=device)
         self.alive_attn = None

         self.min_length = min_length
         self.max_length = max_length

         self.block_ngram_repeat = block_ngram_repeat
+        n_paths = batch_size * parallel_paths
+        self.forbidden_tokens = [dict() for _ in range(n_paths)]
+
         self.exclusion_tokens = exclusion_tokens
         self.return_attention = return_attention

         self.done = False

+    def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
+        """DecodeStrategy subclasses should override :func:`initialize()`.
+
+        `initialize` should be called before all actions.
+        used to prepare necessary ingredients for decode.
+        """
+        if device is None:
+            device = torch.device('cpu')
+        self.alive_seq = torch.full(
+            [self.batch_size * self.parallel_paths, 1], self.bos,
+            dtype=torch.long, device=device)
+        self.is_finished = torch.zeros(
+            [self.batch_size, self.parallel_paths],
+            dtype=torch.uint8, device=device)
+        return None, memory_bank, src_lengths, src_map
+
     def __len__(self):
         return self.alive_seq.shape[1]
@@ -98,25 +113,75 @@ class DecodeStrategy(object):
         self.is_finished.fill_(1)

     def block_ngram_repeats(self, log_probs):
-        cur_len = len(self)
-        if self.block_ngram_repeat > 0 and cur_len > 1:
-            for path_idx in range(self.alive_seq.shape[0]):
-                # skip BOS
-                hyp = self.alive_seq[path_idx, 1:]
-                ngrams = set()
-                fail = False
-                gram = []
-                for i in range(cur_len - 1):
-                    # Last n tokens, n = block_ngram_repeat
-                    gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:]
-                    # skip the blocking if any token in gram is excluded
-                    if set(gram) & self.exclusion_tokens:
-                        continue
-                    if tuple(gram) in ngrams:
-                        fail = True
-                    ngrams.add(tuple(gram))
-                if fail:
-                    log_probs[path_idx] = -10e20
+        """
+        We prevent the beam from going in any direction that would repeat any
+        ngram of size <block_ngram_repeat> more thant once.
+
+        The way we do it: we maintain a list of all ngrams of size
+        <block_ngram_repeat> that is updated each time the beam advances, and
+        manually put any token that would lead to a repeated ngram to 0.
+
+        This improves on the previous version's complexity:
+           - previous version's complexity: batch_size * beam_size * len(self)
+           - current version's complexity: batch_size * beam_size
+
+        This improves on the previous version's accuracy;
+           - Previous version blocks the whole beam, whereas here we only
+             block specific tokens.
+           - Before the translation would fail when all beams contained
+             repeated ngrams. This is sure to never happen here.
+        """
+
+        # we don't block nothing if the user doesn't want it
+        if self.block_ngram_repeat <= 0:
+            return
+
+        # we can't block nothing beam's too short
+        if len(self) < self.block_ngram_repeat:
+            return
+
+        n = self.block_ngram_repeat - 1
+        for path_idx in range(self.alive_seq.shape[0]):
+            # we check paths one by one
+            current_ngram = tuple(self.alive_seq[path_idx, -n:].tolist())
+            forbidden_tokens = self.forbidden_tokens[path_idx].get(
+                current_ngram, None)
+            if forbidden_tokens is not None:
+                log_probs[path_idx, list(forbidden_tokens)] = -10e20
+
+    def maybe_update_forbidden_tokens(self):
+        """We complete and reorder the list of forbidden_tokens"""
+
+        # we don't forbid nothing if the user doesn't want it
+        if self.block_ngram_repeat <= 0:
+            return
+
+        # we can't forbid nothing if beam's too short
+        if len(self) < self.block_ngram_repeat:
+            return
+
+        n = self.block_ngram_repeat
+
+        forbidden_tokens = list()
+        for path_idx, seq in zip(self.select_indices, self.alive_seq):
+
+            # Reordering forbidden_tokens following beam selection
+            # We rebuild a dict to ensure we get the value and not the pointer
+            forbidden_tokens.append(
+                dict(self.forbidden_tokens[path_idx]))
+
+            # Grabing the newly selected tokens and associated ngram
+            current_ngram = tuple(seq[-n:].tolist())
+
+            # skip the blocking if any token in current_ngram is excluded
+            if set(current_ngram) & self.exclusion_tokens:
+                continue
+
+            forbidden_tokens[-1].setdefault(current_ngram[:-1], set())
+            forbidden_tokens[-1][current_ngram[:-1]].add(current_ngram[-1])
+
+        self.forbidden_tokens = forbidden_tokens

     def advance(self, log_probs, attn):
         """DecodeStrategy subclasses should override :func:`advance()`.
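The new `block_ngram_repeats`/`maybe_update_forbidden_tokens` pair replaces the per-step rescan of every hypothesis with a lookup table from the last (n-1) tokens to the next tokens that would complete an already-seen n-gram. The stand-alone sketch below mirrors that bookkeeping with plain dictionaries; the names are illustrative and it is not the library code.

# Toy version of the forbidden-ngram bookkeeping described in the docstring.
def update_forbidden(forbidden, seq, n):
    """Record the newest n-gram of `seq` (a list of token ids)."""
    if len(seq) < n:
        return
    ngram = tuple(seq[-n:])
    forbidden.setdefault(ngram[:-1], set()).add(ngram[-1])


def blocked_next_tokens(forbidden, seq, n):
    """Tokens that may not follow `seq` without repeating an n-gram."""
    if len(seq) < n - 1:
        return set()
    return forbidden.get(tuple(seq[-(n - 1):]), set())


forbidden = {}
seq = [5, 7, 9, 5, 7]
for i in range(3, len(seq) + 1):
    update_forbidden(forbidden, seq[:i], n=3)
print(blocked_next_tokens(forbidden, seq, n=3))  # {9}: "5 7 9" already occurred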
@@ -56,7 +56,7 @@ def sample_with_temperature(logits, sampling_temp, keep_topk):
     return topk_ids, topk_scores


-class RandomSampling(DecodeStrategy):
+class GreedySearch(DecodeStrategy):
     """Select next tokens randomly from the top k possible next tokens.

     The ``scores`` attribute's lists are the score, after applying temperature,
@@ -68,7 +68,6 @@ class RandomSampling(DecodeStrategy):
         bos (int): See base.
         eos (int): See base.
         batch_size (int): See base.
-        device (torch.device or str): See base ``device``.
         min_length (int): See base.
         max_length (int): See base.
         block_ngram_repeat (int): See base.
@@ -76,30 +75,49 @@ class RandomSampling(DecodeStrategy):
         return_attention (bool): See base.
         max_length (int): See base.
         sampling_temp (float): See
-            :func:`~onmt.translate.random_sampling.sample_with_temperature()`.
+            :func:`~onmt.translate.greedy_search.sample_with_temperature()`.
         keep_topk (int): See
-            :func:`~onmt.translate.random_sampling.sample_with_temperature()`.
-        memory_length (LongTensor): Lengths of encodings. Used for
-            masking attention.
+            :func:`~onmt.translate.greedy_search.sample_with_temperature()`.
     """

-    def __init__(self, pad, bos, eos, batch_size, device,
-                 min_length, block_ngram_repeat, exclusion_tokens,
-                 return_attention, max_length, sampling_temp, keep_topk,
-                 memory_length):
-        super(RandomSampling, self).__init__(
-            pad, bos, eos, batch_size, device, 1,
-            min_length, block_ngram_repeat, exclusion_tokens,
-            return_attention, max_length)
+    def __init__(self, pad, bos, eos, batch_size, min_length,
+                 block_ngram_repeat, exclusion_tokens, return_attention,
+                 max_length, sampling_temp, keep_topk):
+        assert block_ngram_repeat == 0
+        super(GreedySearch, self).__init__(
+            pad, bos, eos, batch_size, 1, min_length, block_ngram_repeat,
+            exclusion_tokens, return_attention, max_length)
         self.sampling_temp = sampling_temp
         self.keep_topk = keep_topk
         self.topk_scores = None
-        self.memory_length = memory_length
-        self.batch_size = batch_size
-        self.select_indices = torch.arange(self.batch_size,
-                                           dtype=torch.long, device=device)
-        self.original_batch_idx = torch.arange(self.batch_size,
-                                               dtype=torch.long, device=device)
+
+    def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
+        """Initialize for decoding."""
+        fn_map_state = None
+
+        if isinstance(memory_bank, tuple):
+            mb_device = memory_bank[0].device
+        else:
+            mb_device = memory_bank.device
+        if device is None:
+            device = mb_device
+
+        self.memory_lengths = src_lengths
+        super(GreedySearch, self).initialize(
+            memory_bank, src_lengths, src_map, device)
+        self.select_indices = torch.arange(
+            self.batch_size, dtype=torch.long, device=device)
+        self.original_batch_idx = torch.arange(
+            self.batch_size, dtype=torch.long, device=device)
+        return fn_map_state, memory_bank, self.memory_lengths, src_map
+
+    @property
+    def current_predictions(self):
+        return self.alive_seq[:, -1]
+
+    @property
+    def batch_offset(self):
+        return self.select_indices

     def advance(self, log_probs, attn):
         """Select next tokens randomly from the top k possible next tokens.
@@ -138,7 +156,7 @@ class RandomSampling(DecodeStrategy):
             self.scores[b_orig].append(self.topk_scores[b, 0])
             self.predictions[b_orig].append(self.alive_seq[b, 1:])
             self.attention[b_orig].append(
-                self.alive_attn[:, b, :self.memory_length[b]]
+                self.alive_attn[:, b, :self.memory_lengths[b]]
                 if self.alive_attn is not None else [])
         self.done = self.is_finished.all()
         if self.done:
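GreedySearch drives decoding through `sample_with_temperature`, controlled by the `sampling_temp` and `keep_topk` arguments above. The following is a minimal sketch of that idea, assuming the usual recipe of dividing logits by the temperature and masking everything outside the top k before sampling; it is not the library's exact implementation.

# Minimal top-k sampling with temperature (illustrative, not library code).
import torch


def sample_topk_with_temperature(logits, sampling_temp=1.0, keep_topk=-1):
    logits = logits / max(sampling_temp, 1e-6)
    if keep_topk > 0:
        top_values, _ = torch.topk(logits, keep_topk, dim=-1)
        kth_best = top_values[:, -1].unsqueeze(-1)
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    ids = torch.multinomial(probs, 1)            # (batch, 1) sampled token ids
    scores = torch.log(probs).gather(1, ids)     # log-prob of each sample
    return ids, scores


ids, scores = sample_topk_with_temperature(torch.randn(2, 10), 0.8, keep_topk=3)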
@@ -3,23 +3,21 @@
 from __future__ import print_function
 import codecs
 import os
-import math
 import time
-from itertools import count
+import numpy as np
+from itertools import count, zip_longest

 import torch

 import onmt.model_builder
-import onmt.translate.beam
 import onmt.inputters as inputters
 import onmt.decoders.ensemble
 from onmt.translate.beam_search import BeamSearch
-from onmt.translate.random_sampling import RandomSampling
-from onmt.utils.misc import tile, set_random_seed
+from onmt.translate.greedy_search import GreedySearch
+from onmt.utils.misc import tile, set_random_seed, report_matrix
+from onmt.utils.alignment import extract_alignment, build_align_pharaoh
 from onmt.modules.copy_generator import collapse_copy_scores

-import editdistance


 def build_translator(opt, report_score=True, logger=None, out_file=None):
     if out_file is None:
@@ -38,12 +36,32 @@ def build_translator(opt, report_score=True, logger=None, out_file=None):
         model_opt,
         global_scorer=scorer,
         out_file=out_file,
+        report_align=opt.report_align,
         report_score=report_score,
         logger=logger
     )
     return translator


+def max_tok_len(new, count, sofar):
+    """
+    In token batching scheme, the number of sequences is limited
+    such that the total number of src/tgt tokens (including padding)
+    in a batch <= batch_size
+    """
+    # Maintains the longest src and tgt length in the current batch
+    global max_src_in_batch  # this is a hack
+    # Reset current longest length at a new batch (count=1)
+    if count == 1:
+        max_src_in_batch = 0
+        # max_tgt_in_batch = 0
+    # Src: [<bos> w1 ... wN <eos>]
+    max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2)
+    # Tgt: [w1 ... wM <eos>]
+    src_elements = count * max_src_in_batch
+    return src_elements
+
+
 class Translator(object):
     """Translate a batch of sentences with a saved model.
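`max_tok_len` is passed to the data iterator as a `batch_size_fn` so that batches are capped by (padded) token count rather than by sentence count. The loop below is a simplified, stand-alone illustration of that batching rule, not the torchtext iterator the library actually uses.

# Illustrative token-count batching; sentence lists stand in for examples.
def batch_by_tokens(examples, max_tokens):
    """Yield batches whose padded token count stays under `max_tokens`."""
    batch, longest = [], 0
    for ex in examples:                      # ex: a list of source tokens
        longest = max(longest, len(ex) + 2)  # + 2 for <bos>/<eos>, as above
        if batch and (len(batch) + 1) * longest > max_tokens:
            yield batch
            batch, longest = [], len(ex) + 2
        batch.append(ex)
    if batch:
        yield batch


sents = [["a"] * n for n in (5, 7, 30, 4, 4, 4)]
print([len(b) for b in batch_by_tokens(sents, max_tokens=64)])  # [2, 2, 2]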
@@ -61,9 +79,9 @@ class Translator(object):
             :class:`onmt.translate.decode_strategy.DecodeStrategy`.
         beam_size (int): Number of beams.
         random_sampling_topk (int): See
-            :class:`onmt.translate.random_sampling.RandomSampling`.
+            :class:`onmt.translate.greedy_search.GreedySearch`.
         random_sampling_temp (int): See
-            :class:`onmt.translate.random_sampling.RandomSampling`.
+            :class:`onmt.translate.greedy_search.GreedySearch`.
         stepwise_penalty (bool): Whether coverage penalty is applied every step
             or not.
         dump_beam (bool): Debugging option.
@@ -74,8 +92,6 @@ class Translator(object):
         replace_unk (bool): Replace unknown token.
         data_type (str): Source data type.
         verbose (bool): Print/log every translation.
-        report_bleu (bool): Print/log Bleu metric.
-        report_rouge (bool): Print/log Rouge metric.
         report_time (bool): Print/log total time/frequency.
         copy_attn (bool): Use copy attention.
         global_scorer (onmt.translate.GNMTGlobalScorer): Translation
@@ -104,14 +120,14 @@ class Translator(object):
             block_ngram_repeat=0,
             ignore_when_blocking=frozenset(),
             replace_unk=False,
+            phrase_table="",
             data_type="text",
             verbose=False,
-            report_bleu=False,
-            report_rouge=False,
             report_time=False,
             copy_attn=False,
             global_scorer=None,
             out_file=None,
+            report_align=False,
             report_score=True,
             logger=None,
             seed=-1):
@@ -151,10 +167,9 @@ class Translator(object):
         if self.replace_unk and not self.model.decoder.attentional:
             raise ValueError(
                 "replace_unk requires an attentional decoder.")
+        self.phrase_table = phrase_table
         self.data_type = data_type
         self.verbose = verbose
-        self.report_bleu = report_bleu
-        self.report_rouge = report_rouge
         self.report_time = report_time

         self.copy_attn = copy_attn
@@ -165,6 +180,7 @@ class Translator(object):
             raise ValueError(
                 "Coverage penalty requires an attentional decoder.")
         self.out_file = out_file
+        self.report_align = report_align
         self.report_score = report_score
         self.logger = logger
@@ -192,6 +208,7 @@ class Translator(object):
             model_opt,
             global_scorer=None,
             out_file=None,
+            report_align=False,
             report_score=True,
             logger=None):
         """Alternate constructor.
@@ -207,6 +224,7 @@ class Translator(object):
                 :func:`__init__()`..
             out_file (TextIO or codecs.StreamReaderWriter): See
                 :func:`__init__()`.
+            report_align (bool) : See :func:`__init__()`.
             report_score (bool) : See :func:`__init__()`.
             logger (logging.Logger or NoneType): See :func:`__init__()`.
         """
@@ -231,14 +249,14 @@ class Translator(object):
             block_ngram_repeat=opt.block_ngram_repeat,
             ignore_when_blocking=set(opt.ignore_when_blocking),
             replace_unk=opt.replace_unk,
+            phrase_table=opt.phrase_table,
             data_type=opt.data_type,
             verbose=opt.verbose,
-            report_bleu=opt.report_bleu,
-            report_rouge=opt.report_rouge,
             report_time=opt.report_time,
             copy_attn=model_opt.copy_attn,
             global_scorer=global_scorer,
             out_file=out_file,
+            report_align=report_align,
             report_score=report_score,
             logger=logger,
             seed=opt.seed)
@ -266,7 +284,10 @@ class Translator(object):
|
||||||
tgt=None,
|
tgt=None,
|
||||||
src_dir=None,
|
src_dir=None,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
|
batch_type="sents",
|
||||||
attn_debug=False,
|
attn_debug=False,
|
||||||
|
align_debug=False,
|
||||||
|
phrase_table="",
|
||||||
partial=None,
|
partial=None,
|
||||||
dymax_len=None):
|
dymax_len=None):
|
||||||
"""Translate content of ``src`` and get gold scores from ``tgt``.
|
"""Translate content of ``src`` and get gold scores from ``tgt``.
|
||||||
|
@ -278,6 +299,7 @@ class Translator(object):
|
||||||
for certain types of data).
|
for certain types of data).
|
||||||
batch_size (int): size of examples per mini-batch
|
batch_size (int): size of examples per mini-batch
|
||||||
attn_debug (bool): enables the attention logging
|
attn_debug (bool): enables the attention logging
|
||||||
|
align_debug (bool): enables the word alignment logging
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(`list`, `list`)
|
(`list`, `list`)
|
||||||
|
@ -327,12 +349,16 @@ class Translator(object):
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
raise ValueError("batch_size must be set")
|
raise ValueError("batch_size must be set")
|
||||||
|
|
||||||
|
src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
|
||||||
|
tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
|
||||||
|
_readers, _data, _dir = inputters.Dataset.config(
|
||||||
|
[('src', src_data), ('tgt', tgt_data)])
|
||||||
|
|
||||||
|
# corpus_id field is useless here
|
||||||
|
if self.fields.get("corpus_id", None) is not None:
|
||||||
|
self.fields.pop('corpus_id')
|
||||||
data = inputters.Dataset(
|
data = inputters.Dataset(
|
||||||
self.fields,
|
self.fields, readers=_readers, data=_data, dirs=_dir,
|
||||||
readers=([self.src_reader, self.tgt_reader]
|
|
||||||
if tgt else [self.src_reader]),
|
|
||||||
data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
|
|
||||||
dirs=[src_dir, None] if tgt else [src_dir],
|
|
||||||
sort_key=inputters.str2sortkey[self.data_type],
|
sort_key=inputters.str2sortkey[self.data_type],
|
||||||
filter_pred=self._filter_pred
|
filter_pred=self._filter_pred
|
||||||
)
|
)
|
||||||
|
@ -341,6 +367,7 @@ class Translator(object):
|
||||||
dataset=data,
|
dataset=data,
|
||||||
device=self._dev,
|
device=self._dev,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
batch_size_fn=max_tok_len if batch_type == "tokens" else None,
|
||||||
train=False,
|
train=False,
|
||||||
sort=False,
|
sort=False,
|
||||||
sort_within_batch=True,
|
sort_within_batch=True,
|
||||||
|
@ -348,7 +375,8 @@ class Translator(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
xlation_builder = onmt.translate.TranslationBuilder(
|
xlation_builder = onmt.translate.TranslationBuilder(
|
||||||
data, self.fields, self.n_best, self.replace_unk, tgt
|
data, self.fields, self.n_best, self.replace_unk, tgt,
|
||||||
|
self.phrase_table
|
||||||
)
|
)
|
||||||
|
|
||||||
# Statistics
|
# Statistics
|
||||||
|
@ -377,6 +405,14 @@ class Translator(object):
|
||||||
|
|
||||||
n_best_preds = [" ".join(pred)
|
n_best_preds = [" ".join(pred)
|
||||||
for pred in trans.pred_sents[:self.n_best]]
|
for pred in trans.pred_sents[:self.n_best]]
|
||||||
|
if self.report_align:
|
||||||
|
align_pharaohs = [build_align_pharaoh(align) for align
|
||||||
|
in trans.word_aligns[:self.n_best]]
|
||||||
|
n_best_preds_align = [" ".join(align) for align
|
||||||
|
in align_pharaohs]
|
||||||
|
n_best_preds = [pred + " ||| " + align
|
||||||
|
for pred, align in zip(
|
||||||
|
n_best_preds, n_best_preds_align)]
|
||||||
all_predictions += [n_best_preds]
|
all_predictions += [n_best_preds]
|
||||||
self.out_file.write('\n'.join(n_best_preds) + '\n')
|
self.out_file.write('\n'.join(n_best_preds) + '\n')
|
||||||
self.out_file.flush()
|
self.out_file.flush()
|
||||||
|
@ -397,27 +433,28 @@ class Translator(object):
|
||||||
srcs = trans.src_raw
|
srcs = trans.src_raw
|
||||||
else:
|
else:
|
||||||
srcs = [str(item) for item in range(len(attns[0]))]
|
srcs = [str(item) for item in range(len(attns[0]))]
|
||||||
header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
|
output = report_matrix(srcs, preds, attns)
|
||||||
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
|
|
||||||
output = header_format.format("", *srcs) + '\n'
|
if self.logger:
|
||||||
covatn = []
|
self.logger.info(output)
|
||||||
covatn2d = []
|
else:
|
||||||
print(len(preds),len(attns))
|
os.write(1, output.encode('utf-8'))
|
||||||
for word, row in zip(preds, attns):
|
|
||||||
max_index = row.index(max(row))
|
if align_debug:
|
||||||
# if not covatn:
|
if trans.gold_sent is not None:
|
||||||
# covatn = [0]*len(row)
|
tgts = trans.gold_sent
|
||||||
# covatn[max_index] += 1
|
else:
|
||||||
covatn2d.append(row)
|
tgts = trans.pred_sents[0]
|
||||||
row_format = row_format.replace(
|
align = trans.word_aligns[0].tolist()
|
||||||
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
|
if self.data_type == 'text':
|
||||||
row_format = row_format.replace(
|
srcs = trans.src_raw
|
||||||
"{:*>10.7f} ", "{:>10.7f} ", max_index)
|
else:
|
||||||
output += row_format.format(word, *row) + '\n'
|
srcs = [str(item) for item in range(len(align[0]))]
|
||||||
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
|
output = report_matrix(srcs, tgts, align)
|
||||||
print(output)
|
if self.logger:
|
||||||
# print(covatn2d)
|
self.logger.info(output)
|
||||||
# os.write(1, output.encode('utf-8'))
|
else:
|
||||||
|
os.write(1, output.encode('utf-8'))
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
|
@ -429,12 +466,6 @@ class Translator(object):
|
||||||
msg = self._report_score('GOLD', gold_score_total,
|
msg = self._report_score('GOLD', gold_score_total,
|
||||||
gold_words_total)
|
gold_words_total)
|
||||||
self._log(msg)
|
self._log(msg)
|
||||||
if self.report_bleu:
|
|
||||||
msg = self._report_bleu(tgt)
|
|
||||||
self._log(msg)
|
|
||||||
if self.report_rouge:
|
|
||||||
msg = self._report_rouge(tgt)
|
|
||||||
self._log(msg)
|
|
||||||
|
|
||||||
if self.report_time:
|
if self.report_time:
|
||||||
total_time = end_time - start_time
|
total_time = end_time - start_time
|
||||||
|
@ -448,124 +479,126 @@ class Translator(object):
|
||||||
import json
|
import json
|
||||||
json.dump(self.translator.beam_accum,
|
json.dump(self.translator.beam_accum,
|
||||||
codecs.open(self.dump_beam, 'w', 'utf-8'))
|
codecs.open(self.dump_beam, 'w', 'utf-8'))
|
||||||
|
|
||||||
if attn_debug:
|
if attn_debug:
|
||||||
return all_scores, all_predictions, covatn2d
|
return all_scores, all_predictions, attns
|
||||||
else:
|
else:
|
||||||
return all_scores, all_predictions
|
return all_scores, all_predictions
|
||||||
|
|
||||||
def _translate_random_sampling(
|
def _align_pad_prediction(self, predictions, bos, pad):
|
||||||
self,
|
"""
|
||||||
batch,
|
Padding predictions in batch and add BOS.
|
||||||
src_vocabs,
|
|
||||||
max_length,
|
|
||||||
min_length=0,
|
|
||||||
sampling_temp=1.0,
|
|
||||||
keep_topk=-1,
|
|
||||||
return_attention=False):
|
|
||||||
"""Alternative to beam search. Do random sampling at each step."""
|
|
||||||
|
|
||||||
assert self.beam_size == 1
|
Args:
|
||||||
|
predictions (List[List[Tensor]]): `(batch, n_best,)`, for each src
|
||||||
|
sequence contain n_best tgt predictions all of which ended with
|
||||||
|
eos id.
|
||||||
|
bos (int): bos index to be used.
|
||||||
|
pad (int): pad index to be used.
|
||||||
|
|
||||||
# TODO: support these blacklisted features.
|
Return:
|
||||||
assert self.block_ngram_repeat == 0
|
batched_nbest_predict (torch.LongTensor): `(batch, n_best, tgt_l)`
|
||||||
|
"""
|
||||||
|
dtype, device = predictions[0][0].dtype, predictions[0][0].device
|
||||||
|
flatten_tgt = [best.tolist() for bests in predictions
|
||||||
|
for best in bests]
|
||||||
|
paded_tgt = torch.tensor(
|
||||||
|
list(zip_longest(*flatten_tgt, fillvalue=pad)),
|
||||||
|
dtype=dtype, device=device).T
|
||||||
|
bos_tensor = torch.full([paded_tgt.size(0), 1], bos,
|
||||||
|
dtype=dtype, device=device)
|
||||||
|
full_tgt = torch.cat((bos_tensor, paded_tgt), dim=-1)
|
||||||
|
batched_nbest_predict = full_tgt.view(
|
||||||
|
len(predictions), -1, full_tgt.size(-1)) # (batch, n_best, tgt_l)
|
||||||
|
return batched_nbest_predict
|
||||||
|
|
||||||
batch_size = batch.batch_size
|
def _align_forward(self, batch, predictions):
|
||||||
|
"""
|
||||||
|
For a batch of input and its prediction, return a list of batch predict
|
||||||
|
alignment src indice Tensor in size ``(batch, n_best,)``.
|
||||||
|
"""
|
||||||
|
# (0) add BOS and padding to tgt prediction
|
||||||
|
if hasattr(batch, 'tgt'):
|
||||||
|
batch_tgt_idxs = batch.tgt.transpose(1, 2).transpose(0, 2)
|
||||||
|
else:
|
||||||
|
batch_tgt_idxs = self._align_pad_prediction(
|
||||||
|
predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx)
|
||||||
|
tgt_mask = (batch_tgt_idxs.eq(self._tgt_pad_idx) |
|
||||||
|
batch_tgt_idxs.eq(self._tgt_eos_idx) |
|
||||||
|
batch_tgt_idxs.eq(self._tgt_bos_idx))
|
||||||
|
|
||||||
# Encoder forward.
|
n_best = batch_tgt_idxs.size(1)
|
||||||
|
# (1) Encoder forward.
|
||||||
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
|
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
|
||||||
|
|
||||||
|
# (2) Repeat src objects `n_best` times.
|
||||||
|
# We use batch_size x n_best, get ``(src_len, batch * n_best, nfeat)``
|
||||||
|
src = tile(src, n_best, dim=1)
|
||||||
|
enc_states = tile(enc_states, n_best, dim=1)
|
||||||
|
if isinstance(memory_bank, tuple):
|
||||||
|
memory_bank = tuple(tile(x, n_best, dim=1) for x in memory_bank)
|
||||||
|
else:
|
||||||
|
memory_bank = tile(memory_bank, n_best, dim=1)
|
||||||
|
src_lengths = tile(src_lengths, n_best) # ``(batch * n_best,)``
|
||||||
|
|
||||||
|
# (3) Init decoder with n_best src,
|
||||||
self.model.decoder.init_state(src, memory_bank, enc_states)
|
self.model.decoder.init_state(src, memory_bank, enc_states)
|
||||||
|
# reshape tgt to ``(len, batch * n_best, nfeat)``
|
||||||
|
tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
|
||||||
|
dec_in = tgt[:-1] # exclude last target from inputs
|
||||||
|
_, attns = self.model.decoder(
|
||||||
|
dec_in, memory_bank, memory_lengths=src_lengths, with_align=True)
|
||||||
|
|
||||||
use_src_map = self.copy_attn
|
alignment_attn = attns["align"] # ``(B, tgt_len-1, src_len)``
|
||||||
|
# masked_select
|
||||||
results = {
|
align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
|
||||||
"predictions": None,
|
prediction_mask = align_tgt_mask[:, 1:] # exclude bos to match pred
|
||||||
"scores": None,
|
# get aligned src id for each prediction's valid tgt tokens
|
||||||
"attention": None,
|
alignement = extract_alignment(
|
||||||
"batch": batch,
|
alignment_attn, prediction_mask, src_lengths, n_best)
|
||||||
"gold_score": self._gold_score(
|
return alignement
|
||||||
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
|
||||||
enc_states, batch_size, src)}
|
|
||||||
|
|
||||||
memory_lengths = src_lengths
|
|
||||||
src_map = batch.src_map if use_src_map else None
|
|
||||||
|
|
||||||
if isinstance(memory_bank, tuple):
|
|
||||||
mb_device = memory_bank[0].device
|
|
||||||
else:
|
|
||||||
mb_device = memory_bank.device
|
|
||||||
|
|
||||||
random_sampler = RandomSampling(
|
|
||||||
self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
|
|
||||||
batch_size, mb_device, min_length, self.block_ngram_repeat,
|
|
||||||
self._exclusion_idxs, return_attention, self.max_length,
|
|
||||||
sampling_temp, keep_topk, memory_lengths)
|
|
||||||
|
|
||||||
for step in range(max_length):
|
|
||||||
# Shape: (1, B, 1)
|
|
||||||
decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)
|
|
||||||
|
|
||||||
log_probs, attn = self._decode_and_generate(
|
|
||||||
decoder_input,
|
|
||||||
memory_bank,
|
|
||||||
batch,
|
|
||||||
src_vocabs,
|
|
||||||
memory_lengths=memory_lengths,
|
|
||||||
src_map=src_map,
|
|
||||||
step=step,
|
|
||||||
batch_offset=random_sampler.select_indices
|
|
||||||
)
|
|
||||||
|
|
||||||
random_sampler.advance(log_probs, attn)
|
|
||||||
any_batch_is_finished = random_sampler.is_finished.any()
|
|
||||||
if any_batch_is_finished:
|
|
||||||
random_sampler.update_finished()
|
|
||||||
if random_sampler.done:
|
|
||||||
break
|
|
||||||
|
|
||||||
if any_batch_is_finished:
|
|
||||||
select_indices = random_sampler.select_indices
|
|
||||||
|
|
||||||
# Reorder states.
|
|
||||||
if isinstance(memory_bank, tuple):
|
|
||||||
memory_bank = tuple(x.index_select(1, select_indices)
|
|
||||||
for x in memory_bank)
|
|
||||||
else:
|
|
||||||
memory_bank = memory_bank.index_select(1, select_indices)
|
|
||||||
|
|
||||||
memory_lengths = memory_lengths.index_select(0, select_indices)
|
|
||||||
|
|
||||||
if src_map is not None:
|
|
||||||
src_map = src_map.index_select(1, select_indices)
|
|
||||||
|
|
||||||
self.model.decoder.map_state(
|
|
||||||
lambda state, dim: state.index_select(dim, select_indices))
|
|
||||||
|
|
||||||
results["scores"] = random_sampler.scores
|
|
||||||
results["predictions"] = random_sampler.predictions
|
|
||||||
results["attention"] = random_sampler.attention
|
|
||||||
return results
|
|
||||||
|
|
||||||
def translate_batch(self, batch, src_vocabs, attn_debug):
|
def translate_batch(self, batch, src_vocabs, attn_debug):
|
||||||
"""Translate a batch of sentences."""
|
"""Translate a batch of sentences."""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.beam_size == 1:
|
if self.beam_size == 1:
|
||||||
return self._translate_random_sampling(
|
decode_strategy = GreedySearch(
|
||||||
batch,
|
pad=self._tgt_pad_idx,
|
||||||
src_vocabs,
|
bos=self._tgt_bos_idx,
|
||||||
self.max_length,
|
eos=self._tgt_eos_idx,
|
||||||
min_length=self.min_length,
|
batch_size=batch.batch_size,
|
||||||
|
min_length=self.min_length, max_length=self.max_length,
|
||||||
|
block_ngram_repeat=self.block_ngram_repeat,
|
||||||
|
exclusion_tokens=self._exclusion_idxs,
|
||||||
|
return_attention=attn_debug or self.replace_unk,
|
||||||
sampling_temp=self.random_sampling_temp,
|
sampling_temp=self.random_sampling_temp,
|
||||||
keep_topk=self.sample_from_topk,
|
keep_topk=self.sample_from_topk)
|
||||||
return_attention=attn_debug or self.replace_unk)
|
|
||||||
else:
|
else:
|
||||||
return self._translate_batch(
|
# TODO: support these blacklisted features
|
||||||
batch,
|
assert not self.dump_beam
|
||||||
src_vocabs,
|
max_length = self.max_length
|
||||||
self.max_length,
|
if self.dymax_len is not None:
|
||||||
min_length=self.min_length,
|
if self.partial:
|
||||||
ratio=self.ratio,
|
max_length = len(self.partial) + 3
|
||||||
|
if not self.partial and self.partialf:
|
||||||
|
max_length = 3
|
||||||
|
|
||||||
|
decode_strategy = BeamSearch(
|
||||||
|
self.beam_size,
|
||||||
|
batch_size=batch.batch_size,
|
||||||
|
pad=self._tgt_pad_idx,
|
||||||
|
bos=self._tgt_bos_idx,
|
||||||
|
eos=self._tgt_eos_idx,
|
||||||
n_best=self.n_best,
|
n_best=self.n_best,
|
||||||
return_attention=attn_debug or self.replace_unk)
|
global_scorer=self.global_scorer,
|
||||||
|
min_length=self.min_length, max_length=max_length,
|
||||||
|
return_attention=attn_debug or self.replace_unk,
|
||||||
|
block_ngram_repeat=self.block_ngram_repeat,
|
||||||
|
exclusion_tokens=self._exclusion_idxs,
|
||||||
|
stepwise_penalty=self.stepwise_penalty,
|
||||||
|
ratio=self.ratio)
|
||||||
|
return self._translate_batch_with_strategy(batch, src_vocabs,
|
||||||
|
decode_strategy)
|
||||||
|
|
||||||
def _run_encoder(self, batch):
|
def _run_encoder(self, batch):
|
||||||
src, src_lengths = batch.src if isinstance(batch.src, tuple) \
|
src, src_lengths = batch.src if isinstance(batch.src, tuple) \
|
||||||
|
@ -622,7 +655,8 @@ class Translator(object):
|
||||||
src_map)
|
src_map)
|
||||||
# here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
|
# here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
|
||||||
if batch_offset is None:
|
if batch_offset is None:
|
||||||
scores = scores.view(batch.batch_size, -1, scores.size(-1))
|
scores = scores.view(-1, batch.batch_size, scores.size(-1))
|
||||||
|
scores = scores.transpose(0, 1).contiguous()
|
||||||
else:
|
else:
|
||||||
scores = scores.view(-1, self.beam_size, scores.size(-1))
|
scores = scores.view(-1, self.beam_size, scores.size(-1))
|
||||||
scores = collapse_copy_scores(
|
scores = collapse_copy_scores(
|
||||||
|
@ -639,25 +673,25 @@ class Translator(object):
|
||||||
# or [ tgt_len, batch_size, vocab ] when full sentence
|
# or [ tgt_len, batch_size, vocab ] when full sentence
|
||||||
return log_probs, attn
|
return log_probs, attn
|
||||||
|
|
||||||
def _translate_batch(
|
def _translate_batch_with_strategy(
|
||||||
self,
|
self,
|
||||||
batch,
|
batch,
|
||||||
src_vocabs,
|
src_vocabs,
|
||||||
max_length,
|
decode_strategy):
|
||||||
min_length=0,
|
"""Translate a batch of sentences step by step using cache.
|
||||||
ratio=0.,
|
|
||||||
n_best=1,
|
Args:
|
||||||
return_attention=False):
|
batch: a batch of sentences, yield by data iterator.
|
||||||
# TODO: support these blacklisted features.
|
src_vocabs (list): list of torchtext.data.Vocab if can_copy.
|
||||||
assert not self.dump_beam
|
decode_strategy (DecodeStrategy): A decode strategy to use for
|
||||||
if self.dymax_len is not None:
|
generate translation step by step.
|
||||||
if self.partial:
|
|
||||||
max_length = len(self.partial) + 3
|
Returns:
|
||||||
if not self.partial and self.partialf:
|
results (dict): The translation results.
|
||||||
max_length = 3
|
"""
|
||||||
# (0) Prep the components of the search.
|
# (0) Prep the components of the search.
|
||||||
use_src_map = self.copy_attn
|
use_src_map = self.copy_attn
|
||||||
beam_size = self.beam_size
|
parallel_paths = decode_strategy.parallel_paths # beam_size
|
||||||
batch_size = batch.batch_size
|
batch_size = batch.batch_size
|
||||||
|
|
||||||
# (1) Run the encoder on the src.
|
# (1) Run the encoder on the src.
|
||||||
|
@ -673,42 +707,16 @@ class Translator(object):
|
||||||
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
||||||
enc_states, batch_size, src)}
|
enc_states, batch_size, src)}
|
||||||
|
|
||||||
# (2) Repeat src objects `beam_size` times.
|
# (2) prep decode_strategy. Possibly repeat src objects.
|
||||||
# We use batch_size x beam_size
|
src_map = batch.src_map if use_src_map else None
|
||||||
src_map = (tile(batch.src_map, beam_size, dim=1)
|
fn_map_state, memory_bank, memory_lengths, src_map = \
|
||||||
if use_src_map else None)
|
decode_strategy.initialize(memory_bank, src_lengths, src_map)
|
||||||
self.model.decoder.map_state(
|
if fn_map_state is not None:
|
||||||
lambda state, dim: tile(state, beam_size, dim=dim))
|
self.model.decoder.map_state(fn_map_state)
|
||||||
|
|
||||||
if isinstance(memory_bank, tuple):
|
# (3) Begin decoding step by step:
|
||||||
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
|
for step in range(decode_strategy.max_length):
|
||||||
mb_device = memory_bank[0].device
|
decoder_input = decode_strategy.current_predictions.view(1, -1, 1)
|
||||||
else:
|
|
||||||
memory_bank = tile(memory_bank, beam_size, dim=1)
|
|
||||||
mb_device = memory_bank.device
|
|
||||||
memory_lengths = tile(src_lengths, beam_size)
|
|
||||||
|
|
||||||
# (0) pt 2, prep the beam object
|
|
||||||
beam = BeamSearch(
|
|
||||||
beam_size,
|
|
||||||
n_best=n_best,
|
|
||||||
batch_size=batch_size,
|
|
||||||
global_scorer=self.global_scorer,
|
|
||||||
pad=self._tgt_pad_idx,
|
|
||||||
eos=self._tgt_eos_idx,
|
|
||||||
bos=self._tgt_bos_idx,
|
|
||||||
min_length=min_length,
|
|
||||||
ratio=ratio,
|
|
||||||
max_length=max_length,
|
|
||||||
mb_device=mb_device,
|
|
||||||
return_attention=return_attention,
|
|
||||||
stepwise_penalty=self.stepwise_penalty,
|
|
||||||
block_ngram_repeat=self.block_ngram_repeat,
|
|
||||||
exclusion_tokens=self._exclusion_idxs,
|
|
||||||
memory_lengths=memory_lengths)
|
|
||||||
|
|
||||||
for step in range(max_length):
|
|
||||||
decoder_input = beam.current_predictions.view(1, -1, 1)
|
|
||||||
|
|
||||||
log_probs, attn = self._decode_and_generate(
|
log_probs, attn = self._decode_and_generate(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
|
@ -718,18 +726,18 @@ class Translator(object):
|
||||||
memory_lengths=memory_lengths,
|
memory_lengths=memory_lengths,
|
||||||
src_map=src_map,
|
src_map=src_map,
|
||||||
step=step,
|
step=step,
|
||||||
batch_offset=beam._batch_offset)
|
batch_offset=decode_strategy.batch_offset)
|
||||||
|
|
||||||
beam.advance(log_probs, attn, self.partial, self.partialf)
|
decode_strategy.advance(log_probs, attn, self.partial, self.partialf)
|
||||||
any_beam_is_finished = beam.is_finished.any()
|
any_finished = decode_strategy.is_finished.any()
|
||||||
if any_beam_is_finished:
|
if any_finished:
|
||||||
beam.update_finished()
|
decode_strategy.update_finished()
|
||||||
if beam.done:
|
if decode_strategy.done:
|
||||||
break
|
break
|
||||||
|
|
||||||
select_indices = beam.current_origin
|
select_indices = decode_strategy.select_indices
|
||||||
|
|
||||||
if any_beam_is_finished:
|
if any_finished:
|
||||||
# Reorder states.
|
# Reorder states.
|
||||||
if isinstance(memory_bank, tuple):
|
if isinstance(memory_bank, tuple):
|
||||||
memory_bank = tuple(x.index_select(1, select_indices)
|
memory_bank = tuple(x.index_select(1, select_indices)
|
||||||
|
@ -742,107 +750,18 @@ class Translator(object):
|
||||||
if src_map is not None:
|
if src_map is not None:
|
||||||
src_map = src_map.index_select(1, select_indices)
|
src_map = src_map.index_select(1, select_indices)
|
||||||
|
|
||||||
|
if parallel_paths > 1 or any_finished:
|
||||||
self.model.decoder.map_state(
|
self.model.decoder.map_state(
|
||||||
lambda state, dim: state.index_select(dim, select_indices))
|
lambda state, dim: state.index_select(dim, select_indices))
|
||||||
|
|
||||||
results["scores"] = beam.scores
|
results["scores"] = decode_strategy.scores
|
||||||
results["predictions"] = beam.predictions
|
results["predictions"] = decode_strategy.predictions
|
||||||
results["attention"] = beam.attention
|
results["attention"] = decode_strategy.attention
|
||||||
return results
|
if self.report_align:
|
||||||
|
results["alignment"] = self._align_forward(
|
||||||
# This is left in the code for now, but unsued
|
batch, decode_strategy.predictions)
|
||||||
def _translate_batch_deprecated(self, batch, src_vocabs):
|
|
||||||
# (0) Prep each of the components of the search.
|
|
||||||
# And helper method for reducing verbosity.
|
|
||||||
use_src_map = self.copy_attn
|
|
||||||
beam_size = self.beam_size
|
|
||||||
batch_size = batch.batch_size
|
|
||||||
|
|
||||||
beam = [onmt.translate.Beam(
|
|
||||||
beam_size,
|
|
||||||
n_best=self.n_best,
|
|
||||||
cuda=self.cuda,
|
|
||||||
global_scorer=self.global_scorer,
|
|
||||||
pad=self._tgt_pad_idx,
|
|
||||||
eos=self._tgt_eos_idx,
|
|
||||||
bos=self._tgt_bos_idx,
|
|
||||||
min_length=self.min_length,
|
|
||||||
stepwise_penalty=self.stepwise_penalty,
|
|
||||||
block_ngram_repeat=self.block_ngram_repeat,
|
|
||||||
exclusion_tokens=self._exclusion_idxs)
|
|
||||||
for __ in range(batch_size)]
|
|
||||||
|
|
||||||
# (1) Run the encoder on the src.
|
|
||||||
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
|
|
||||||
self.model.decoder.init_state(src, memory_bank, enc_states)
|
|
||||||
|
|
||||||
results = {
|
|
||||||
"predictions": [],
|
|
||||||
"scores": [],
|
|
||||||
"attention": [],
|
|
||||||
"batch": batch,
|
|
||||||
"gold_score": self._gold_score(
|
|
||||||
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
|
||||||
enc_states, batch_size, src)}
|
|
||||||
|
|
||||||
# (2) Repeat src objects `beam_size` times.
|
|
||||||
# We use now batch_size x beam_size (same as fast mode)
|
|
||||||
src_map = (tile(batch.src_map, beam_size, dim=1)
|
|
||||||
if use_src_map else None)
|
|
||||||
self.model.decoder.map_state(
|
|
||||||
lambda state, dim: tile(state, beam_size, dim=dim))
|
|
||||||
|
|
||||||
if isinstance(memory_bank, tuple):
|
|
||||||
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
|
|
||||||
else:
|
else:
|
||||||
memory_bank = tile(memory_bank, beam_size, dim=1)
|
results["alignment"] = [[] for _ in range(batch_size)]
|
||||||
memory_lengths = tile(src_lengths, beam_size)
|
|
||||||
|
|
||||||
# (3) run the decoder to generate sentences, using beam search.
|
|
||||||
for i in range(self.max_length):
|
|
||||||
if all((b.done for b in beam)):
|
|
||||||
break
|
|
||||||
|
|
||||||
# (a) Construct batch x beam_size nxt words.
|
|
||||||
# Get all the pending current beam words and arrange for forward.
|
|
||||||
|
|
||||||
inp = torch.stack([b.current_predictions for b in beam])
|
|
||||||
inp = inp.view(1, -1, 1)
|
|
||||||
|
|
||||||
# (b) Decode and forward
|
|
||||||
out, beam_attn = self._decode_and_generate(
|
|
||||||
inp, memory_bank, batch, src_vocabs,
|
|
||||||
memory_lengths=memory_lengths, src_map=src_map, step=i
|
|
||||||
)
|
|
||||||
out = out.view(batch_size, beam_size, -1)
|
|
||||||
beam_attn = beam_attn.view(batch_size, beam_size, -1)
|
|
||||||
|
|
||||||
# (c) Advance each beam.
|
|
||||||
select_indices_array = []
|
|
||||||
# Loop over the batch_size number of beam
|
|
||||||
for j, b in enumerate(beam):
|
|
||||||
if not b.done:
|
|
||||||
b.advance(out[j, :],
|
|
||||||
beam_attn.data[j, :, :memory_lengths[j]])
|
|
||||||
select_indices_array.append(
|
|
||||||
b.current_origin + j * beam_size)
|
|
||||||
select_indices = torch.cat(select_indices_array)
|
|
||||||
|
|
||||||
self.model.decoder.map_state(
|
|
||||||
lambda state, dim: state.index_select(dim, select_indices))
|
|
||||||
|
|
||||||
# (4) Extract sentences from beam.
|
|
||||||
for b in beam:
|
|
||||||
scores, ks = b.sort_finished(minimum=self.n_best)
|
|
||||||
hyps, attn = [], []
|
|
||||||
for times, k in ks[:self.n_best]:
|
|
||||||
hyp, att = b.get_hyp(times, k)
|
|
||||||
hyps.append(hyp)
|
|
||||||
attn.append(att)
|
|
||||||
results["predictions"].append(hyps)
|
|
||||||
results["scores"].append(scores)
|
|
||||||
results["attention"].append(attn)
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _score_target(self, batch, memory_bank, src_lengths,
|
def _score_target(self, batch, memory_bank, src_lengths,
|
||||||
|
@@ -855,7 +774,7 @@ class Translator(object):
             memory_lengths=src_lengths, src_map=src_map)

         log_probs[:, :, self._tgt_pad_idx] = 0
-        gold = tgt_in
+        gold = tgt[1:]
         gold_scores = log_probs.gather(2, gold)
         gold_scores = gold_scores.sum(dim=0).view(-1)
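The gold-scoring step above gathers the log-probability of every reference token and sums over time to get one score per sentence. A small self-contained check of the tensor shapes involved (the sizes are made up for the example):

import torch

tgt_len, batch, vocab = 4, 2, 10
log_probs = torch.log_softmax(torch.randn(tgt_len, batch, vocab), dim=-1)
gold = torch.randint(vocab, (tgt_len, batch, 1))              # reference ids
gold_scores = log_probs.gather(2, gold).sum(dim=0).view(-1)   # shape (batch,)
print(gold_scores.shape)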
@@ -865,31 +784,9 @@ class Translator(object):
         if words_total == 0:
             msg = "%s No words predicted" % (name,)
         else:
+            avg_score = score_total / words_total
+            ppl = np.exp(-score_total.item() / words_total)
             msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
-                name, score_total / words_total,
-                name, math.exp(-score_total / words_total)))
+                name, avg_score,
+                name, ppl))
         return msg
-
-    def _report_bleu(self, tgt_path):
-        import subprocess
-        base_dir = os.path.abspath(__file__ + "/../../..")
-        # Rollback pointer to the beginning.
-        self.out_file.seek(0)
-        print()
-
-        res = subprocess.check_output(
-            "perl %s/tools/multi-bleu.perl %s" % (base_dir, tgt_path),
-            stdin=self.out_file, shell=True
-        ).decode("utf-8")
-
-        msg = ">> " + res.strip()
-        return msg
-
-    def _report_rouge(self, tgt_path):
-        import subprocess
-        path = os.path.split(os.path.realpath(__file__))[0]
-        msg = subprocess.check_output(
-            "python %s/tools/test_rouge.py -r %s -c STDIN" % (path, tgt_path),
-            shell=True, stdin=self.out_file
-        ).decode("utf-8").strip()
-        return msg
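The reworked `_report_score` keeps the usual relation between average log-probability and perplexity, PPL = exp(-avg log-prob). A quick numeric check of that arithmetic:

import numpy as np

score_total, words_total = -200.0, 100
avg_score = score_total / words_total          # -2.0 nats per word
ppl = np.exp(-score_total / words_total)       # ~7.389
print(avg_score, ppl)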
@@ -3,31 +3,47 @@ from snownlp import SnowNLP
 import pkuseg


+def wrap_str_func(func):
+    """
+    Wrapper to apply str function to the proper key of return_dict.
+    """
+    def wrapper(some_dict):
+        some_dict["seg"] = [func(item) for item in some_dict["seg"]]
+        return some_dict
+    return wrapper
+
+
 # Chinese segmentation
-def zh_segmentator(line):
+@wrap_str_func
+def zh_segmentator(line, server_model):
     return " ".join(pkuseg.pkuseg().cut(line))


 # Chinese simplify -> Chinese traditional standard
-def zh_traditional_standard(line):
+@wrap_str_func
+def zh_traditional_standard(line, server_model):
     return HanLP.convertToTraditionalChinese(line)


 # Chinese simplify -> Chinese traditional (HongKong)
-def zh_traditional_hk(line):
+@wrap_str_func
+def zh_traditional_hk(line, server_model):
     return HanLP.s2hk(line)


 # Chinese simplify -> Chinese traditional (Taiwan)
-def zh_traditional_tw(line):
+@wrap_str_func
+def zh_traditional_tw(line, server_model):
     return HanLP.s2tw(line)


 # Chinese traditional -> Chinese simplify (v1)
-def zh_simplify(line):
+@wrap_str_func
+def zh_simplify(line, server_model):
     return HanLP.convertToSimplifiedChinese(line)


 # Chinese traditional -> Chinese simplify (v2)
-def zh_simplify_v2(line):
+@wrap_str_func
+def zh_simplify_v2(line, server_model):
     return SnowNLP(line).han
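For context, `wrap_str_func` maps a per-string postprocess function over the "seg" list of the request dict. A hypothetical usage sketch follows; the `lowercase` function and its `server_model=None` default are ours (the default makes the single-argument call inside the wrapper work for this toy case):

def wrap_str_func(func):
    # copied from the diff above, shown here so the sketch is self-contained
    def wrapper(some_dict):
        some_dict["seg"] = [func(item) for item in some_dict["seg"]]
        return some_dict
    return wrapper


@wrap_str_func
def lowercase(line, server_model=None):
    return line.lower()


print(lowercase({"seg": ["Hello World", "FOO"]}))
# {'seg': ['hello world', 'foo']}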
@@ -3,6 +3,7 @@ from __future__ import unicode_literals, print_function

 import torch
 from onmt.inputters.text_dataset import TextMultiField
+from onmt.utils.alignment import build_align_pharaoh


 class TranslationBuilder(object):
@@ -62,14 +63,18 @@ class TranslationBuilder(object):
                len(translation_batch["predictions"]))
         batch_size = batch.batch_size

-        preds, pred_score, attn, gold_score, indices = list(zip(
+        preds, pred_score, attn, align, gold_score, indices = list(zip(
             *sorted(zip(translation_batch["predictions"],
                         translation_batch["scores"],
                         translation_batch["attention"],
+                        translation_batch["alignment"],
                         translation_batch["gold_score"],
                         batch.indices.data),
                     key=lambda x: x[-1])))

+        if not any(align):  # when align is a empty nested list
+            align = [None] * batch_size
+
         # Sorting
         inds, perm = torch.sort(batch.indices)
         if self._has_text_src:
@@ -91,7 +96,8 @@ class TranslationBuilder(object):
             pred_sents = [self._build_target_tokens(
                 src[:, b] if src is not None else None,
                 src_vocab, src_raw,
-                preds[b][n], attn[b][n])
+                preds[b][n],
+                align[b][n] if align[b] is not None else attn[b][n])
                 for n in range(self.n_best)]
             gold_sent = None
             if tgt is not None:
@@ -103,7 +109,7 @@ class TranslationBuilder(object):
             translation = Translation(
                 src[:, b] if src is not None else None,
                 src_raw, pred_sents, attn[b], pred_score[b],
-                gold_sent, gold_score[b]
+                gold_sent, gold_score[b], align[b]
             )
             translations.append(translation)

@@ -122,13 +128,15 @@ class Translation(object):
             translation.
         gold_sent (List[str]): Words from gold translation.
         gold_score (List[float]): Log-prob of gold translation.
+        word_aligns (List[FloatTensor]): Words Alignment distribution for
+            each translation.
     """

     __slots__ = ["src", "src_raw", "pred_sents", "attns", "pred_scores",
-                 "gold_sent", "gold_score"]
+                 "gold_sent", "gold_score", "word_aligns"]

     def __init__(self, src, src_raw, pred_sents,
-                 attn, pred_scores, tgt_sent, gold_score):
+                 attn, pred_scores, tgt_sent, gold_score, word_aligns):
         self.src = src
         self.src_raw = src_raw
         self.pred_sents = pred_sents
@@ -136,6 +144,7 @@ class Translation(object):
         self.pred_scores = pred_scores
         self.gold_sent = tgt_sent
         self.gold_score = gold_score
+        self.word_aligns = word_aligns

     def log(self, sent_number):
         """
@@ -150,6 +159,12 @@ class Translation(object):
         msg.append('PRED {}: {}\n'.format(sent_number, pred_sent))
         msg.append("PRED SCORE: {:.4f}\n".format(best_score))

+        if self.word_aligns is not None:
+            pred_align = self.word_aligns[0]
+            pred_align_pharaoh = build_align_pharaoh(pred_align)
+            pred_align_sent = ' '.join(pred_align_pharaoh)
+            msg.append("ALIGN: {}\n".format(pred_align_sent))
+
         if self.gold_sent is not None:
             tgt_sent = ' '.join(self.gold_sent)
             msg.append('GOLD {}: {}\n'.format(sent_number, tgt_sent))
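The ALIGN line added to `Translation.log` uses the Pharaoh text format, where each "s-t" pair means source token s aligns to target token t (both 0-indexed). The helper below is an illustrative stand-in for `build_align_pharaoh`, not the library function:

import torch


def align_matrix_to_pharaoh(align):
    """align: (tgt_len, src_len) matrix of alignment weights."""
    src_ids = align.argmax(dim=-1)                # best source token per target
    return ["{}-{}".format(int(s), t) for t, s in enumerate(src_ids)]


align = torch.tensor([[0.9, 0.1, 0.0],
                      [0.2, 0.7, 0.1],
                      [0.1, 0.2, 0.7]])
print(" ".join(align_matrix_to_pharaoh(align)))   # "0-0 1-1 2-2"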
@ -13,8 +13,13 @@ import importlib
|
||||||
import torch
|
import torch
|
||||||
import onmt.opts
|
import onmt.opts
|
||||||
|
|
||||||
|
from itertools import islice
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from onmt.utils.logging import init_logger
|
from onmt.utils.logging import init_logger
|
||||||
from onmt.utils.misc import set_random_seed
|
from onmt.utils.misc import set_random_seed
|
||||||
|
from onmt.utils.misc import check_model_config
|
||||||
|
from onmt.utils.alignment import to_word_align
|
||||||
from onmt.utils.parse import ArgumentParser
|
from onmt.utils.parse import ArgumentParser
|
||||||
from onmt.translate.translator import build_translator
|
from onmt.translate.translator import build_translator
|
||||||
|
|
||||||
|
@ -69,6 +74,53 @@ class ServerModelError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CTranslate2Translator(object):
|
||||||
|
"""
|
||||||
|
This class wraps the ctranslate2.Translator object to
|
||||||
|
reproduce the onmt.translate.translator API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_path, device, device_index,
|
||||||
|
batch_size, beam_size, n_best, preload=False):
|
||||||
|
import ctranslate2
|
||||||
|
self.translator = ctranslate2.Translator(
|
||||||
|
model_path,
|
||||||
|
device=device,
|
||||||
|
device_index=device_index,
|
||||||
|
inter_threads=1,
|
||||||
|
intra_threads=1,
|
||||||
|
compute_type="default")
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.beam_size = beam_size
|
||||||
|
self.n_best = n_best
|
||||||
|
if preload:
|
||||||
|
# perform a first request to initialize everything
|
||||||
|
dummy_translation = self.translate(["a"])
|
||||||
|
print("Performed a dummy translation to initialize the model",
|
||||||
|
dummy_translation)
|
||||||
|
time.sleep(1)
|
||||||
|
self.translator.unload_model(to_cpu=True)
|
||||||
|
|
||||||
|
def translate(self, texts_to_translate, batch_size=8):
|
||||||
|
batch = [item.split(" ") for item in texts_to_translate]
|
||||||
|
preds = self.translator.translate_batch(
|
||||||
|
batch,
|
||||||
|
max_batch_size=self.batch_size,
|
||||||
|
beam_size=self.beam_size,
|
||||||
|
num_hypotheses=self.n_best
|
||||||
|
)
|
||||||
|
scores = [[item["score"] for item in ex] for ex in preds]
|
||||||
|
predictions = [[" ".join(item["tokens"]) for item in ex]
|
||||||
|
for ex in preds]
|
||||||
|
return scores, predictions
|
||||||
|
|
||||||
|
def to_cpu(self):
|
||||||
|
self.translator.unload_model(to_cpu=True)
|
||||||
|
|
||||||
|
def to_gpu(self):
|
||||||
|
self.translator.load_model()
|
||||||
|
|
||||||
|
|
||||||
class TranslationServer(object):
def __init__(self):
self.models = {}
@ -89,13 +141,15 @@ class TranslationServer(object):
else:
raise ValueError("""Incorrect config file: missing 'models'
parameter for model #%d""" % i)
+ check_model_config(conf, self.models_root)
kwargs = {'timeout': conf.get('timeout', None),
'load': conf.get('load', None),
'preprocess_opt': conf.get('preprocess', None),
'tokenizer_opt': conf.get('tokenizer', None),
'postprocess_opt': conf.get('postprocess', None),
'on_timeout': conf.get('on_timeout', None),
- 'model_root': conf.get('model_root', self.models_root)
+ 'model_root': conf.get('model_root', self.models_root),
+ 'ct2_model': conf.get('ct2_model', None)
}
kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
model_id = conf.get("id", None)
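For orientation only (this note is not part of the committed diff): the new `'ct2_model'` entry above is read from the server's JSON model configuration alongside the existing per-model keys. A minimal sketch of such a model entry, written as a Python dict, with placeholder ids and file names; the `"opt"` block is assumed from the usual OpenNMT-py server config and is not shown in this hunk:

```python
# Hypothetical conf.json entry for one served model (placeholder values).
# Only the keys mirror the conf.get(...) calls above; names are made up.
example_model_conf = {
    "id": 100,
    "models": ["model.pt"],            # regular OpenNMT-py checkpoint(s)
    "ct2_model": "model_ctranslate2",  # optional CTranslate2 export directory
    "timeout": 600,
    "on_timeout": "to_cpu",
    "load": True,
    "opt": {"gpu": 0, "beam_size": 5, "batch_size": 8},
}
```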
|
@ -202,11 +256,9 @@ class ServerModel(object):
|
||||||
|
|
||||||
def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
|
def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
|
||||||
postprocess_opt=None, load=False, timeout=-1,
|
postprocess_opt=None, load=False, timeout=-1,
|
||||||
on_timeout="to_cpu", model_root="./"):
|
on_timeout="to_cpu", model_root="./", ct2_model=None):
|
||||||
self.model_root = model_root
|
self.model_root = model_root
|
||||||
self.opt = self.parse_opt(opt)
|
self.opt = self.parse_opt(opt)
|
||||||
if self.opt.n_best > 1:
|
|
||||||
raise ValueError("Values of n_best > 1 are not supported")
|
|
||||||
|
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
self.preprocess_opt = preprocess_opt
|
self.preprocess_opt = preprocess_opt
|
||||||
|
@ -215,6 +267,9 @@ class ServerModel(object):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.on_timeout = on_timeout
|
self.on_timeout = on_timeout
|
||||||
|
|
||||||
|
self.ct2_model = os.path.join(model_root, ct2_model) \
|
||||||
|
if ct2_model is not None else None
|
||||||
|
|
||||||
self.unload_timer = None
|
self.unload_timer = None
|
||||||
self.user_opt = opt
|
self.user_opt = opt
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
@ -224,7 +279,8 @@ class ServerModel(object):
|
||||||
else:
|
else:
|
||||||
log_file = None
|
log_file = None
|
||||||
self.logger = init_logger(log_file=log_file,
|
self.logger = init_logger(log_file=log_file,
|
||||||
log_file_level=self.opt.log_file_level)
|
log_file_level=self.opt.log_file_level,
|
||||||
|
rotate=True)
|
||||||
|
|
||||||
self.loading_lock = threading.Event()
|
self.loading_lock = threading.Event()
|
||||||
self.loading_lock.set()
|
self.loading_lock.set()
|
||||||
|
@ -232,67 +288,6 @@ class ServerModel(object):
|
||||||
|
|
||||||
set_random_seed(self.opt.seed, self.opt.cuda)
|
set_random_seed(self.opt.seed, self.opt.cuda)
|
||||||
|
|
||||||
if load:
|
|
||||||
self.load()
|
|
||||||
|
|
||||||
def parse_opt(self, opt):
|
|
||||||
"""Parse the option set passed by the user using `onmt.opts`
|
|
||||||
|
|
||||||
Args:
|
|
||||||
opt (dict): Options passed by the user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
opt (argparse.Namespace): full set of options for the Translator
|
|
||||||
"""
|
|
||||||
|
|
||||||
prec_argv = sys.argv
|
|
||||||
sys.argv = sys.argv[:1]
|
|
||||||
parser = ArgumentParser()
|
|
||||||
onmt.opts.translate_opts(parser)
|
|
||||||
|
|
||||||
models = opt['models']
|
|
||||||
if not isinstance(models, (list, tuple)):
|
|
||||||
models = [models]
|
|
||||||
opt['models'] = [os.path.join(self.model_root, model)
|
|
||||||
for model in models]
|
|
||||||
opt['src'] = "dummy_src"
|
|
||||||
|
|
||||||
for (k, v) in opt.items():
|
|
||||||
if k == 'models':
|
|
||||||
sys.argv += ['-model']
|
|
||||||
sys.argv += [str(model) for model in v]
|
|
||||||
elif type(v) == bool:
|
|
||||||
sys.argv += ['-%s' % k]
|
|
||||||
else:
|
|
||||||
sys.argv += ['-%s' % k, str(v)]
|
|
||||||
|
|
||||||
opt = parser.parse_args()
|
|
||||||
ArgumentParser.validate_translate_opts(opt)
|
|
||||||
opt.cuda = opt.gpu > -1
|
|
||||||
|
|
||||||
sys.argv = prec_argv
|
|
||||||
return opt
|
|
||||||
|
|
||||||
@property
|
|
||||||
def loaded(self):
|
|
||||||
return hasattr(self, 'translator')
|
|
||||||
|
|
||||||
def load(self):
|
|
||||||
self.loading_lock.clear()
|
|
||||||
|
|
||||||
timer = Timer()
|
|
||||||
self.logger.info("Loading model %d" % self.model_id)
|
|
||||||
timer.start()
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.translator = build_translator(self.opt,
|
|
||||||
report_score=False,
|
|
||||||
out_file=codecs.open(
|
|
||||||
os.devnull, "w", "utf-8"))
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise ServerModelError("Runtime Error: %s" % str(e))
|
|
||||||
|
|
||||||
timer.tick("model_loading")
|
|
||||||
if self.preprocess_opt is not None:
|
if self.preprocess_opt is not None:
|
||||||
self.logger.info("Loading preprocessor")
|
self.logger.info("Loading preprocessor")
|
||||||
self.preprocessor = []
|
self.preprocessor = []
|
||||||
|
@ -347,6 +342,77 @@ class ServerModel(object):
|
||||||
function = get_function_by_path(function_path)
|
function = get_function_by_path(function_path)
|
||||||
self.postprocessor.append(function)
|
self.postprocessor.append(function)
|
||||||
|
|
||||||
|
if load:
|
||||||
|
self.load(preload=True)
|
||||||
|
self.stop_unload_timer()
|
||||||
|
|
||||||
|
def parse_opt(self, opt):
|
||||||
|
"""Parse the option set passed by the user using `onmt.opts`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
opt (dict): Options passed by the user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
opt (argparse.Namespace): full set of options for the Translator
|
||||||
|
"""
|
||||||
|
|
||||||
|
prec_argv = sys.argv
|
||||||
|
sys.argv = sys.argv[:1]
|
||||||
|
parser = ArgumentParser()
|
||||||
|
onmt.opts.translate_opts(parser)
|
||||||
|
|
||||||
|
models = opt['models']
|
||||||
|
if not isinstance(models, (list, tuple)):
|
||||||
|
models = [models]
|
||||||
|
opt['models'] = [os.path.join(self.model_root, model)
|
||||||
|
for model in models]
|
||||||
|
opt['src'] = "dummy_src"
|
||||||
|
|
||||||
|
for (k, v) in opt.items():
|
||||||
|
if k == 'models':
|
||||||
|
sys.argv += ['-model']
|
||||||
|
sys.argv += [str(model) for model in v]
|
||||||
|
elif type(v) == bool:
|
||||||
|
sys.argv += ['-%s' % k]
|
||||||
|
else:
|
||||||
|
sys.argv += ['-%s' % k, str(v)]
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
ArgumentParser.validate_translate_opts(opt)
|
||||||
|
opt.cuda = opt.gpu > -1
|
||||||
|
|
||||||
|
sys.argv = prec_argv
|
||||||
|
return opt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loaded(self):
|
||||||
|
return hasattr(self, 'translator')
|
||||||
|
|
||||||
|
def load(self, preload=False):
|
||||||
|
self.loading_lock.clear()
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
self.logger.info("Loading model %d" % self.model_id)
|
||||||
|
timer.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self.ct2_model is not None:
|
||||||
|
self.translator = CTranslate2Translator(
|
||||||
|
self.ct2_model,
|
||||||
|
device="cuda" if self.opt.cuda else "cpu",
|
||||||
|
device_index=self.opt.gpu if self.opt.cuda else 0,
|
||||||
|
batch_size=self.opt.batch_size,
|
||||||
|
beam_size=self.opt.beam_size,
|
||||||
|
n_best=self.opt.n_best,
|
||||||
|
preload=preload)
|
||||||
|
else:
|
||||||
|
self.translator = build_translator(
|
||||||
|
self.opt, report_score=False,
|
||||||
|
out_file=codecs.open(os.devnull, "w", "utf-8"))
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise ServerModelError("Runtime Error: %s" % str(e))
|
||||||
|
|
||||||
|
timer.tick("model_loading")
|
||||||
self.load_time = timer.tick()
|
self.load_time = timer.tick()
|
||||||
self.reset_unload_timer()
|
self.reset_unload_timer()
|
||||||
self.loading_lock.set()
|
self.loading_lock.set()
|
||||||
|
@ -390,13 +456,9 @@ class ServerModel(object):
|
||||||
head_spaces = []
|
head_spaces = []
|
||||||
tail_spaces = []
|
tail_spaces = []
|
||||||
sslength = []
|
sslength = []
|
||||||
|
all_preprocessed = []
|
||||||
for i, inp in enumerate(inputs):
|
for i, inp in enumerate(inputs):
|
||||||
src = inp['src']
|
src = inp['src']
|
||||||
if src.strip() == "":
|
|
||||||
head_spaces.append(src)
|
|
||||||
texts.append("")
|
|
||||||
tail_spaces.append("")
|
|
||||||
else:
|
|
||||||
whitespaces_before, whitespaces_after = "", ""
|
whitespaces_before, whitespaces_after = "", ""
|
||||||
match_before = re.search(r'^\s+', src)
|
match_before = re.search(r'^\s+', src)
|
||||||
match_after = re.search(r'\s+$', src)
|
match_after = re.search(r'\s+$', src)
|
||||||
|
@ -405,8 +467,11 @@ class ServerModel(object):
|
||||||
if match_after is not None:
|
if match_after is not None:
|
||||||
whitespaces_after = match_after.group(0)
|
whitespaces_after = match_after.group(0)
|
||||||
head_spaces.append(whitespaces_before)
|
head_spaces.append(whitespaces_before)
|
||||||
preprocessed_src = self.maybe_preprocess(src.strip())
|
# every segment becomes a dict for flexibility purposes
|
||||||
tok = self.maybe_tokenize(preprocessed_src)
|
seg_dict = self.maybe_preprocess(src.strip())
|
||||||
|
all_preprocessed.append(seg_dict)
|
||||||
|
for seg in seg_dict["seg"]:
|
||||||
|
tok = self.maybe_tokenize(seg)
|
||||||
texts.append(tok)
|
texts.append(tok)
|
||||||
sslength.append(len(tok.split()))
|
sslength.append(len(tok.split()))
|
||||||
tail_spaces.append(whitespaces_after)
|
tail_spaces.append(whitespaces_after)
|
||||||
|
@ -441,28 +506,69 @@ class ServerModel(object):
|
||||||
self.reset_unload_timer()
|
self.reset_unload_timer()
|
||||||
|
|
||||||
# NOTE: translator returns lists of `n_best` list
|
# NOTE: translator returns lists of `n_best` list
|
||||||
# we can ignore that (i.e. flatten lists) only because
|
|
||||||
# we restrict `n_best=1`
|
|
||||||
def flatten_list(_list): return sum(_list, [])
|
def flatten_list(_list): return sum(_list, [])
|
||||||
|
tiled_texts = [t for t in texts_to_translate
|
||||||
|
for _ in range(self.opt.n_best)]
|
||||||
results = flatten_list(predictions)
|
results = flatten_list(predictions)
|
||||||
scores = [score_tensor.item()
|
|
||||||
|
def maybe_item(x): return x.item() if type(x) is torch.Tensor else x
|
||||||
|
scores = [maybe_item(score_tensor)
|
||||||
for score_tensor in flatten_list(scores)]
|
for score_tensor in flatten_list(scores)]
|
||||||
|
|
||||||
results = [self.maybe_detokenize(item)
|
results = [self.maybe_detokenize_with_align(result, src)
|
||||||
for item in results]
|
for result, src in zip(results, tiled_texts)]
|
||||||
|
|
||||||
|
aligns = [align for _, align in results]
|
||||||
|
|
||||||
results = [self.maybe_postprocess(item)
|
|
||||||
for item in results]
|
|
||||||
# build back results with empty texts
|
# build back results with empty texts
|
||||||
for i in empty_indices:
|
for i in empty_indices:
|
||||||
results.insert(i, "")
|
j = i * self.opt.n_best
|
||||||
scores.insert(i, 0)
|
results = (results[:j] +
|
||||||
|
[("", None)] * self.opt.n_best + results[j:])
|
||||||
|
aligns = aligns[:j] + [None] * self.opt.n_best + aligns[j:]
|
||||||
|
scores = scores[:j] + [0] * self.opt.n_best + scores[j:]
|
||||||
|
|
||||||
|
rebuilt_segs, scores, aligns = self.rebuild_seg_packages(
|
||||||
|
all_preprocessed, results, scores, aligns, self.opt.n_best)
|
||||||
|
|
||||||
|
results = [self.maybe_postprocess(seg) for seg in rebuilt_segs]
|
||||||
|
|
||||||
|
head_spaces = [h for h in head_spaces for i in range(self.opt.n_best)]
|
||||||
|
tail_spaces = [h for h in tail_spaces for i in range(self.opt.n_best)]
|
||||||
results = ["".join(items)
|
results = ["".join(items)
|
||||||
for items in zip(head_spaces, results, tail_spaces)]
|
for items in zip(head_spaces, results, tail_spaces)]
|
||||||
|
|
||||||
self.logger.info("Translation Results: %d", len(results))
|
self.logger.info("Translation Results: %d", len(results))
|
||||||
return results, scores, self.opt.n_best, timer.times
|
|
||||||
|
return results, scores, self.opt.n_best, timer.times, aligns
|
||||||
|
|
||||||
|
def rebuild_seg_packages(self, all_preprocessed, results,
|
||||||
|
scores, aligns, n_best):
|
||||||
|
"""
|
||||||
|
Rebuild proper segment packages based on initial n_seg.
|
||||||
|
"""
|
||||||
|
offset = 0
|
||||||
|
rebuilt_segs = []
|
||||||
|
avg_scores = []
|
||||||
|
merged_aligns = []
|
||||||
|
for i, seg_dict in enumerate(all_preprocessed):
|
||||||
|
sub_results = results[n_best * offset:
|
||||||
|
(offset + seg_dict["n_seg"]) * n_best]
|
||||||
|
sub_scores = scores[n_best * offset:
|
||||||
|
(offset + seg_dict["n_seg"]) * n_best]
|
||||||
|
sub_aligns = aligns[n_best * offset:
|
||||||
|
(offset + seg_dict["n_seg"]) * n_best]
|
||||||
|
for j in range(n_best):
|
||||||
|
_seg_dict = deepcopy(seg_dict)
|
||||||
|
_sub_segs = list(list(zip(*sub_results))[0])
|
||||||
|
_seg_dict["seg"] = list(islice(_sub_segs, j, None, n_best))
|
||||||
|
rebuilt_segs.append(_seg_dict)
|
||||||
|
sub_sub_scores = list(islice(sub_scores, j, None, n_best))
|
||||||
|
avg_scores.append(sum(sub_sub_scores)/_seg_dict["n_seg"])
|
||||||
|
sub_sub_aligns = list(islice(sub_aligns, j, None, n_best))
|
||||||
|
merged_aligns.append(sub_sub_aligns)
|
||||||
|
offset += _seg_dict["n_seg"]
|
||||||
|
return rebuilt_segs, avg_scores, merged_aligns
|
||||||
|
|
||||||
def do_timeout(self):
|
def do_timeout(self):
|
||||||
"""Timeout function that frees GPU memory.
|
"""Timeout function that frees GPU memory.
|
||||||
|
@ -485,6 +591,7 @@ class ServerModel(object):
|
||||||
del self.translator
|
del self.translator
|
||||||
if self.opt.cuda:
|
if self.opt.cuda:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
self.stop_unload_timer()
|
||||||
self.unload_timer = None
|
self.unload_timer = None
|
||||||
|
|
||||||
def stop_unload_timer(self):
|
def stop_unload_timer(self):
|
||||||
|
@ -515,12 +622,18 @@ class ServerModel(object):
|
||||||
@critical
|
@critical
|
||||||
def to_cpu(self):
|
def to_cpu(self):
|
||||||
"""Move the model to CPU and clear CUDA cache."""
|
"""Move the model to CPU and clear CUDA cache."""
|
||||||
|
if type(self.translator) == CTranslate2Translator:
|
||||||
|
self.translator.to_cpu()
|
||||||
|
else:
|
||||||
self.translator.model.cpu()
|
self.translator.model.cpu()
|
||||||
if self.opt.cuda:
|
if self.opt.cuda:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def to_gpu(self):
|
def to_gpu(self):
|
||||||
"""Move the model to GPU."""
|
"""Move the model to GPU."""
|
||||||
|
if type(self.translator) == CTranslate2Translator:
|
||||||
|
self.translator.to_gpu()
|
||||||
|
else:
|
||||||
torch.cuda.set_device(self.opt.gpu)
|
torch.cuda.set_device(self.opt.gpu)
|
||||||
self.translator.model.cuda()
|
self.translator.model.cuda()
|
||||||
|
|
||||||
|
@ -528,7 +641,11 @@ class ServerModel(object):
|
||||||
"""Preprocess the sequence (or not)
|
"""Preprocess the sequence (or not)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if type(sequence) is str:
|
||||||
|
sequence = {
|
||||||
|
"seg": [sequence],
|
||||||
|
"n_seg": 1
|
||||||
|
}
|
||||||
if self.preprocess_opt is not None:
|
if self.preprocess_opt is not None:
|
||||||
return self.preprocess(sequence)
|
return self.preprocess(sequence)
|
||||||
return sequence
|
return sequence
|
||||||
|
@ -545,7 +662,7 @@ class ServerModel(object):
|
||||||
if self.preprocessor is None:
|
if self.preprocessor is None:
|
||||||
raise ValueError("No preprocessor loaded")
|
raise ValueError("No preprocessor loaded")
|
||||||
for function in self.preprocessor:
|
for function in self.preprocessor:
|
||||||
sequence = function(sequence)
|
sequence = function(sequence, self)
|
||||||
return sequence
|
return sequence
|
||||||
|
|
||||||
def maybe_tokenize(self, sequence):
|
def maybe_tokenize(self, sequence):
|
||||||
|
@ -579,6 +696,42 @@ class ServerModel(object):
|
||||||
tok = " ".join(tok)
|
tok = " ".join(tok)
|
||||||
return tok
|
return tok
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokenizer_marker(self):
|
||||||
|
marker = None
|
||||||
|
tokenizer_type = self.tokenizer_opt.get('type', None)
|
||||||
|
if tokenizer_type == "pyonmttok":
|
||||||
|
params = self.tokenizer_opt.get('params', None)
|
||||||
|
if params is not None:
|
||||||
|
if params.get("joiner_annotate", None) is not None:
|
||||||
|
marker = 'joiner'
|
||||||
|
elif params.get("spacer_annotate", None) is not None:
|
||||||
|
marker = 'spacer'
|
||||||
|
elif tokenizer_type == "sentencepiece":
|
||||||
|
marker = 'spacer'
|
||||||
|
return marker
|
||||||
|
|
||||||
+ def maybe_detokenize_with_align(self, sequence, src):
+ """De-tokenize (or not) the sequence (with alignment).

+ Args:
+ sequence (str): The sequence to detokenize, possibly with
+ an alignment separated by ` ||| `.

+ Returns:
+ sequence (str): The detokenized sequence.
+ align (str): The alignment corresponding to the detokenized
+ src/tgt, sorted, or None if there is no alignment in the output.
+ """
+ align = None
+ if self.opt.report_align:
+ # output contains alignment
+ sequence, align = sequence.split(' ||| ')
+ if align != '':
+ align = self.maybe_convert_align(src, sequence, align)
+ sequence = self.maybe_detokenize(sequence)
+ return (sequence, align)

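A side note, not part of the diff: the ` ||| ` convention handled by `maybe_detokenize_with_align` above separates the translated text from its word alignment in Pharaoh format. A minimal sketch of how a client might split such an output; the sentence and indices are invented, only the separator and the `i-j` pair format come from the code above:

```python
# Minimal sketch: splitting a prediction that carries a Pharaoh-style alignment.
output = "das ist ein Test ||| 0-0 1-1 2-2 3-3"
translation, align = output.split(" ||| ")
pairs = [tuple(int(i) for i in p.split("-")) for p in align.split()]
# pairs -> [(0, 0), (1, 1), (2, 2), (3, 3)], each pair linking a source token
# position to a target token position (0-based).
```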
def maybe_detokenize(self, sequence):
|
def maybe_detokenize(self, sequence):
|
||||||
"""De-tokenize the sequence (or not)
|
"""De-tokenize the sequence (or not)
|
||||||
|
|
||||||
|
@ -605,14 +758,29 @@ class ServerModel(object):
|
||||||
|
|
||||||
return detok
|
return detok
|
||||||
|
|
||||||
|
def maybe_convert_align(self, src, tgt, align):
|
||||||
|
"""Convert alignment to match detokenized src/tgt (or not).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src (str): The tokenized source sequence.
|
||||||
|
tgt (str): The tokenized target sequence.
|
||||||
|
align (str): The alignment corresponding to the src/tgt pair.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
align (str): The alignment corresponding to the detokenized src/tgt.
|
||||||
|
"""
|
||||||
|
if self.tokenizer_marker is not None and ''.join(tgt.split()) != '':
|
||||||
|
return to_word_align(src, tgt, align, mode=self.tokenizer_marker)
|
||||||
|
return align
|
||||||
|
|
||||||
def maybe_postprocess(self, sequence):
|
def maybe_postprocess(self, sequence):
|
||||||
"""Postprocess the sequence (or not)
|
"""Postprocess the sequence (or not)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.postprocess_opt is not None:
|
if self.postprocess_opt is not None:
|
||||||
return self.postprocess(sequence)
|
return self.postprocess(sequence)
|
||||||
return sequence
|
else:
|
||||||
|
return sequence["seg"][0]
|
||||||
|
|
||||||
def postprocess(self, sequence):
|
def postprocess(self, sequence):
|
||||||
"""Preprocess a single sequence.
|
"""Preprocess a single sequence.
|
||||||
|
@ -626,7 +794,7 @@ class ServerModel(object):
|
||||||
if self.postprocessor is None:
|
if self.postprocessor is None:
|
||||||
raise ValueError("No postprocessor loaded")
|
raise ValueError("No postprocessor loaded")
|
||||||
for function in self.postprocessor:
|
for function in self.postprocessor:
|
||||||
sequence = function(sequence)
|
sequence = function(sequence, self)
|
||||||
return sequence
|
return sequence
|
||||||
|
|
||||||
@ -3,19 +3,19 @@
from __future__ import print_function
import codecs
import os
- import math
import time
- from itertools import count
+ import numpy as np
+ from itertools import count, zip_longest

import torch

import onmt.model_builder
- import onmt.translate.beam
import onmt.inputters as inputters
import onmt.decoders.ensemble
from onmt.translate.beam_search import BeamSearch
- from onmt.translate.random_sampling import RandomSampling
+ from onmt.translate.greedy_search import GreedySearch
- from onmt.utils.misc import tile, set_random_seed
+ from onmt.utils.misc import tile, set_random_seed, report_matrix
+ from onmt.utils.alignment import extract_alignment, build_align_pharaoh
from onmt.modules.copy_generator import collapse_copy_scores


@ -36,12 +36,32 @@ def build_translator(opt, report_score=True, logger=None, out_file=None):
model_opt,
global_scorer=scorer,
out_file=out_file,
+ report_align=opt.report_align,
report_score=report_score,
logger=logger
)
return translator


+ def max_tok_len(new, count, sofar):
+ """
+ In token batching scheme, the number of sequences is limited
+ such that the total number of src/tgt tokens (including padding)
+ in a batch <= batch_size
+ """
+ # Maintains the longest src and tgt length in the current batch
+ global max_src_in_batch # this is a hack
+ # Reset current longest length at a new batch (count=1)
+ if count == 1:
+ max_src_in_batch = 0
+ # max_tgt_in_batch = 0
+ # Src: [<bos> w1 ... wN <eos>]
+ max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2)
+ # Tgt: [w1 ... wM <eos>]
+ src_elements = count * max_src_in_batch
+ return src_elements


class Translator(object):
"""Translate a batch of sentences with a saved model.
|
"""Translate a batch of sentences with a saved model.
|
||||||
|
|
||||||
|
@ -59,9 +79,9 @@ class Translator(object):
|
||||||
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
|
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
|
||||||
beam_size (int): Number of beams.
|
beam_size (int): Number of beams.
|
||||||
random_sampling_topk (int): See
|
random_sampling_topk (int): See
|
||||||
:class:`onmt.translate.random_sampling.RandomSampling`.
|
:class:`onmt.translate.greedy_search.GreedySearch`.
|
||||||
random_sampling_temp (int): See
|
random_sampling_temp (int): See
|
||||||
:class:`onmt.translate.random_sampling.RandomSampling`.
|
:class:`onmt.translate.greedy_search.GreedySearch`.
|
||||||
stepwise_penalty (bool): Whether coverage penalty is applied every step
|
stepwise_penalty (bool): Whether coverage penalty is applied every step
|
||||||
or not.
|
or not.
|
||||||
dump_beam (bool): Debugging option.
|
dump_beam (bool): Debugging option.
|
||||||
|
@ -72,8 +92,6 @@ class Translator(object):
|
||||||
replace_unk (bool): Replace unknown token.
|
replace_unk (bool): Replace unknown token.
|
||||||
data_type (str): Source data type.
|
data_type (str): Source data type.
|
||||||
verbose (bool): Print/log every translation.
|
verbose (bool): Print/log every translation.
|
||||||
report_bleu (bool): Print/log Bleu metric.
|
|
||||||
report_rouge (bool): Print/log Rouge metric.
|
|
||||||
report_time (bool): Print/log total time/frequency.
|
report_time (bool): Print/log total time/frequency.
|
||||||
copy_attn (bool): Use copy attention.
|
copy_attn (bool): Use copy attention.
|
||||||
global_scorer (onmt.translate.GNMTGlobalScorer): Translation
|
global_scorer (onmt.translate.GNMTGlobalScorer): Translation
|
||||||
|
@ -105,12 +123,11 @@ class Translator(object):
|
||||||
phrase_table="",
|
phrase_table="",
|
||||||
data_type="text",
|
data_type="text",
|
||||||
verbose=False,
|
verbose=False,
|
||||||
report_bleu=False,
|
|
||||||
report_rouge=False,
|
|
||||||
report_time=False,
|
report_time=False,
|
||||||
copy_attn=False,
|
copy_attn=False,
|
||||||
global_scorer=None,
|
global_scorer=None,
|
||||||
out_file=None,
|
out_file=None,
|
||||||
|
report_align=False,
|
||||||
report_score=True,
|
report_score=True,
|
||||||
logger=None,
|
logger=None,
|
||||||
seed=-1):
|
seed=-1):
|
||||||
|
@ -153,8 +170,6 @@ class Translator(object):
|
||||||
self.phrase_table = phrase_table
|
self.phrase_table = phrase_table
|
||||||
self.data_type = data_type
|
self.data_type = data_type
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.report_bleu = report_bleu
|
|
||||||
self.report_rouge = report_rouge
|
|
||||||
self.report_time = report_time
|
self.report_time = report_time
|
||||||
|
|
||||||
self.copy_attn = copy_attn
|
self.copy_attn = copy_attn
|
||||||
|
@ -165,6 +180,7 @@ class Translator(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Coverage penalty requires an attentional decoder.")
|
"Coverage penalty requires an attentional decoder.")
|
||||||
self.out_file = out_file
|
self.out_file = out_file
|
||||||
|
self.report_align = report_align
|
||||||
self.report_score = report_score
|
self.report_score = report_score
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
|
@ -192,6 +208,7 @@ class Translator(object):
|
||||||
model_opt,
|
model_opt,
|
||||||
global_scorer=None,
|
global_scorer=None,
|
||||||
out_file=None,
|
out_file=None,
|
||||||
|
report_align=False,
|
||||||
report_score=True,
|
report_score=True,
|
||||||
logger=None):
|
logger=None):
|
||||||
"""Alternate constructor.
|
"""Alternate constructor.
|
||||||
|
@ -207,6 +224,7 @@ class Translator(object):
|
||||||
:func:`__init__()`..
|
:func:`__init__()`..
|
||||||
out_file (TextIO or codecs.StreamReaderWriter): See
|
out_file (TextIO or codecs.StreamReaderWriter): See
|
||||||
:func:`__init__()`.
|
:func:`__init__()`.
|
||||||
|
report_align (bool) : See :func:`__init__()`.
|
||||||
report_score (bool) : See :func:`__init__()`.
|
report_score (bool) : See :func:`__init__()`.
|
||||||
logger (logging.Logger or NoneType): See :func:`__init__()`.
|
logger (logging.Logger or NoneType): See :func:`__init__()`.
|
||||||
"""
|
"""
|
||||||
|
@ -234,12 +252,11 @@ class Translator(object):
|
||||||
phrase_table=opt.phrase_table,
|
phrase_table=opt.phrase_table,
|
||||||
data_type=opt.data_type,
|
data_type=opt.data_type,
|
||||||
verbose=opt.verbose,
|
verbose=opt.verbose,
|
||||||
report_bleu=opt.report_bleu,
|
|
||||||
report_rouge=opt.report_rouge,
|
|
||||||
report_time=opt.report_time,
|
report_time=opt.report_time,
|
||||||
copy_attn=model_opt.copy_attn,
|
copy_attn=model_opt.copy_attn,
|
||||||
global_scorer=global_scorer,
|
global_scorer=global_scorer,
|
||||||
out_file=out_file,
|
out_file=out_file,
|
||||||
|
report_align=report_align,
|
||||||
report_score=report_score,
|
report_score=report_score,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
seed=opt.seed)
|
seed=opt.seed)
|
||||||
|
@ -267,7 +284,9 @@ class Translator(object):
|
||||||
tgt=None,
|
tgt=None,
|
||||||
src_dir=None,
|
src_dir=None,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
|
batch_type="sents",
|
||||||
attn_debug=False,
|
attn_debug=False,
|
||||||
|
align_debug=False,
|
||||||
phrase_table=""):
|
phrase_table=""):
|
||||||
"""Translate content of ``src`` and get gold scores from ``tgt``.
|
"""Translate content of ``src`` and get gold scores from ``tgt``.
|
||||||
|
|
||||||
|
@ -278,6 +297,7 @@ class Translator(object):
|
||||||
for certain types of data).
|
for certain types of data).
|
||||||
batch_size (int): size of examples per mini-batch
|
batch_size (int): size of examples per mini-batch
|
||||||
attn_debug (bool): enables the attention logging
|
attn_debug (bool): enables the attention logging
|
||||||
|
align_debug (bool): enables the word alignment logging
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(`list`, `list`)
|
(`list`, `list`)
|
||||||
|
@ -290,12 +310,16 @@ class Translator(object):
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
raise ValueError("batch_size must be set")
|
raise ValueError("batch_size must be set")
|
||||||
|
|
||||||
|
src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
|
||||||
|
tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
|
||||||
|
_readers, _data, _dir = inputters.Dataset.config(
|
||||||
|
[('src', src_data), ('tgt', tgt_data)])
|
||||||
|
|
||||||
|
# corpus_id field is useless here
|
||||||
|
if self.fields.get("corpus_id", None) is not None:
|
||||||
|
self.fields.pop('corpus_id')
|
||||||
data = inputters.Dataset(
|
data = inputters.Dataset(
|
||||||
self.fields,
|
self.fields, readers=_readers, data=_data, dirs=_dir,
|
||||||
readers=([self.src_reader, self.tgt_reader]
|
|
||||||
if tgt else [self.src_reader]),
|
|
||||||
data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
|
|
||||||
dirs=[src_dir, None] if tgt else [src_dir],
|
|
||||||
sort_key=inputters.str2sortkey[self.data_type],
|
sort_key=inputters.str2sortkey[self.data_type],
|
||||||
filter_pred=self._filter_pred
|
filter_pred=self._filter_pred
|
||||||
)
|
)
|
||||||
|
@ -304,6 +328,7 @@ class Translator(object):
|
||||||
dataset=data,
|
dataset=data,
|
||||||
device=self._dev,
|
device=self._dev,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
batch_size_fn=max_tok_len if batch_type == "tokens" else None,
|
||||||
train=False,
|
train=False,
|
||||||
sort=False,
|
sort=False,
|
||||||
sort_within_batch=True,
|
sort_within_batch=True,
|
||||||
|
@ -341,6 +366,14 @@ class Translator(object):
|
||||||
|
|
||||||
n_best_preds = [" ".join(pred)
|
n_best_preds = [" ".join(pred)
|
||||||
for pred in trans.pred_sents[:self.n_best]]
|
for pred in trans.pred_sents[:self.n_best]]
|
||||||
|
if self.report_align:
|
||||||
|
align_pharaohs = [build_align_pharaoh(align) for align
|
||||||
|
in trans.word_aligns[:self.n_best]]
|
||||||
|
n_best_preds_align = [" ".join(align) for align
|
||||||
|
in align_pharaohs]
|
||||||
|
n_best_preds = [pred + " ||| " + align
|
||||||
|
for pred, align in zip(
|
||||||
|
n_best_preds, n_best_preds_align)]
|
||||||
all_predictions += [n_best_preds]
|
all_predictions += [n_best_preds]
|
||||||
self.out_file.write('\n'.join(n_best_preds) + '\n')
|
self.out_file.write('\n'.join(n_best_preds) + '\n')
|
||||||
self.out_file.flush()
|
self.out_file.flush()
|
||||||
|
@ -361,17 +394,23 @@ class Translator(object):
|
||||||
srcs = trans.src_raw
|
srcs = trans.src_raw
|
||||||
else:
|
else:
|
||||||
srcs = [str(item) for item in range(len(attns[0]))]
|
srcs = [str(item) for item in range(len(attns[0]))]
|
||||||
header_format = "{:>10.10} " + "{:>10.7} " * len(srcs)
|
output = report_matrix(srcs, preds, attns)
|
||||||
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
|
if self.logger:
|
||||||
output = header_format.format("", *srcs) + '\n'
|
self.logger.info(output)
|
||||||
for word, row in zip(preds, attns):
|
else:
|
||||||
max_index = row.index(max(row))
|
os.write(1, output.encode('utf-8'))
|
||||||
row_format = row_format.replace(
|
|
||||||
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
|
if align_debug:
|
||||||
row_format = row_format.replace(
|
if trans.gold_sent is not None:
|
||||||
"{:*>10.7f} ", "{:>10.7f} ", max_index)
|
tgts = trans.gold_sent
|
||||||
output += row_format.format(word, *row) + '\n'
|
else:
|
||||||
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
|
tgts = trans.pred_sents[0]
|
||||||
|
align = trans.word_aligns[0].tolist()
|
||||||
|
if self.data_type == 'text':
|
||||||
|
srcs = trans.src_raw
|
||||||
|
else:
|
||||||
|
srcs = [str(item) for item in range(len(align[0]))]
|
||||||
|
output = report_matrix(srcs, tgts, align)
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.info(output)
|
self.logger.info(output)
|
||||||
else:
|
else:
|
||||||
|
@ -387,12 +426,6 @@ class Translator(object):
|
||||||
msg = self._report_score('GOLD', gold_score_total,
|
msg = self._report_score('GOLD', gold_score_total,
|
||||||
gold_words_total)
|
gold_words_total)
|
||||||
self._log(msg)
|
self._log(msg)
|
||||||
if self.report_bleu:
|
|
||||||
msg = self._report_bleu(tgt)
|
|
||||||
self._log(msg)
|
|
||||||
if self.report_rouge:
|
|
||||||
msg = self._report_rouge(tgt)
|
|
||||||
self._log(msg)
|
|
||||||
|
|
||||||
if self.report_time:
|
if self.report_time:
|
||||||
total_time = end_time - start_time
|
total_time = end_time - start_time
|
||||||
|
@ -408,119 +441,113 @@ class Translator(object):
|
||||||
codecs.open(self.dump_beam, 'w', 'utf-8'))
|
codecs.open(self.dump_beam, 'w', 'utf-8'))
|
||||||
return all_scores, all_predictions
|
return all_scores, all_predictions
|
||||||
|
|
||||||
def _translate_random_sampling(
|
def _align_pad_prediction(self, predictions, bos, pad):
|
||||||
self,
|
"""
|
||||||
batch,
|
Padding predictions in batch and add BOS.
|
||||||
src_vocabs,
|
|
||||||
max_length,
|
|
||||||
min_length=0,
|
|
||||||
sampling_temp=1.0,
|
|
||||||
keep_topk=-1,
|
|
||||||
return_attention=False):
|
|
||||||
"""Alternative to beam search. Do random sampling at each step."""
|
|
||||||
|
|
||||||
assert self.beam_size == 1
|
Args:
|
||||||
|
predictions (List[List[Tensor]]): `(batch, n_best,)`, for each src
|
||||||
|
sequence contains n_best tgt predictions, all of which end with
|
||||||
|
eos id.
|
||||||
|
bos (int): bos index to be used.
|
||||||
|
pad (int): pad index to be used.
|
||||||
|
|
||||||
# TODO: support these blacklisted features.
|
Return:
|
||||||
assert self.block_ngram_repeat == 0
|
batched_nbest_predict (torch.LongTensor): `(batch, n_best, tgt_l)`
|
||||||
|
"""
|
||||||
|
dtype, device = predictions[0][0].dtype, predictions[0][0].device
|
||||||
|
flatten_tgt = [best.tolist() for bests in predictions
|
||||||
|
for best in bests]
|
||||||
|
paded_tgt = torch.tensor(
|
||||||
|
list(zip_longest(*flatten_tgt, fillvalue=pad)),
|
||||||
|
dtype=dtype, device=device).T
|
||||||
|
bos_tensor = torch.full([paded_tgt.size(0), 1], bos,
|
||||||
|
dtype=dtype, device=device)
|
||||||
|
full_tgt = torch.cat((bos_tensor, paded_tgt), dim=-1)
|
||||||
|
batched_nbest_predict = full_tgt.view(
|
||||||
|
len(predictions), -1, full_tgt.size(-1)) # (batch, n_best, tgt_l)
|
||||||
|
return batched_nbest_predict
|
||||||
|
|
||||||
batch_size = batch.batch_size
|
def _align_forward(self, batch, predictions):
|
||||||
|
"""
|
||||||
|
For a batch of input and its prediction, return a list of batch predict
|
||||||
|
alignment src indices Tensor in size ``(batch, n_best,)``.
|
||||||
|
"""
|
||||||
|
# (0) add BOS and padding to tgt prediction
|
||||||
|
if hasattr(batch, 'tgt'):
|
||||||
|
batch_tgt_idxs = batch.tgt.transpose(1, 2).transpose(0, 2)
|
||||||
|
else:
|
||||||
|
batch_tgt_idxs = self._align_pad_prediction(
|
||||||
|
predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx)
|
||||||
|
tgt_mask = (batch_tgt_idxs.eq(self._tgt_pad_idx) |
|
||||||
|
batch_tgt_idxs.eq(self._tgt_eos_idx) |
|
||||||
|
batch_tgt_idxs.eq(self._tgt_bos_idx))
|
||||||
|
|
||||||
# Encoder forward.
|
n_best = batch_tgt_idxs.size(1)
|
||||||
|
# (1) Encoder forward.
|
||||||
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
|
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
|
||||||
|
|
||||||
|
# (2) Repeat src objects `n_best` times.
|
||||||
|
# We use batch_size x n_best, get ``(src_len, batch * n_best, nfeat)``
|
||||||
|
src = tile(src, n_best, dim=1)
|
||||||
|
enc_states = tile(enc_states, n_best, dim=1)
|
||||||
|
if isinstance(memory_bank, tuple):
|
||||||
|
memory_bank = tuple(tile(x, n_best, dim=1) for x in memory_bank)
|
||||||
|
else:
|
||||||
|
memory_bank = tile(memory_bank, n_best, dim=1)
|
||||||
|
src_lengths = tile(src_lengths, n_best) # ``(batch * n_best,)``
|
||||||
|
|
||||||
|
# (3) Init decoder with n_best src,
|
||||||
self.model.decoder.init_state(src, memory_bank, enc_states)
|
self.model.decoder.init_state(src, memory_bank, enc_states)
|
||||||
|
# reshape tgt to ``(len, batch * n_best, nfeat)``
|
||||||
|
tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
|
||||||
|
dec_in = tgt[:-1] # exclude last target from inputs
|
||||||
|
_, attns = self.model.decoder(
|
||||||
|
dec_in, memory_bank, memory_lengths=src_lengths, with_align=True)
|
||||||
|
|
||||||
use_src_map = self.copy_attn
|
alignment_attn = attns["align"] # ``(B, tgt_len-1, src_len)``
|
||||||
|
# masked_select
|
||||||
results = {
|
align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
|
||||||
"predictions": None,
|
prediction_mask = align_tgt_mask[:, 1:] # exclude bos to match pred
|
||||||
"scores": None,
|
# get aligned src id for each prediction's valid tgt tokens
|
||||||
"attention": None,
|
alignement = extract_alignment(
|
||||||
"batch": batch,
|
alignment_attn, prediction_mask, src_lengths, n_best)
|
||||||
"gold_score": self._gold_score(
|
return alignement
|
||||||
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
|
||||||
enc_states, batch_size, src)}
|
|
||||||
|
|
||||||
memory_lengths = src_lengths
|
|
||||||
src_map = batch.src_map if use_src_map else None
|
|
||||||
|
|
||||||
if isinstance(memory_bank, tuple):
|
|
||||||
mb_device = memory_bank[0].device
|
|
||||||
else:
|
|
||||||
mb_device = memory_bank.device
|
|
||||||
|
|
||||||
random_sampler = RandomSampling(
|
|
||||||
self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
|
|
||||||
batch_size, mb_device, min_length, self.block_ngram_repeat,
|
|
||||||
self._exclusion_idxs, return_attention, self.max_length,
|
|
||||||
sampling_temp, keep_topk, memory_lengths)
|
|
||||||
|
|
||||||
for step in range(max_length):
|
|
||||||
# Shape: (1, B, 1)
|
|
||||||
decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)
|
|
||||||
|
|
||||||
log_probs, attn = self._decode_and_generate(
|
|
||||||
decoder_input,
|
|
||||||
memory_bank,
|
|
||||||
batch,
|
|
||||||
src_vocabs,
|
|
||||||
memory_lengths=memory_lengths,
|
|
||||||
src_map=src_map,
|
|
||||||
step=step,
|
|
||||||
batch_offset=random_sampler.select_indices
|
|
||||||
)
|
|
||||||
|
|
||||||
random_sampler.advance(log_probs, attn)
|
|
||||||
any_batch_is_finished = random_sampler.is_finished.any()
|
|
||||||
if any_batch_is_finished:
|
|
||||||
random_sampler.update_finished()
|
|
||||||
if random_sampler.done:
|
|
||||||
break
|
|
||||||
|
|
||||||
if any_batch_is_finished:
|
|
||||||
select_indices = random_sampler.select_indices
|
|
||||||
|
|
||||||
# Reorder states.
|
|
||||||
if isinstance(memory_bank, tuple):
|
|
||||||
memory_bank = tuple(x.index_select(1, select_indices)
|
|
||||||
for x in memory_bank)
|
|
||||||
else:
|
|
||||||
memory_bank = memory_bank.index_select(1, select_indices)
|
|
||||||
|
|
||||||
memory_lengths = memory_lengths.index_select(0, select_indices)
|
|
||||||
|
|
||||||
if src_map is not None:
|
|
||||||
src_map = src_map.index_select(1, select_indices)
|
|
||||||
|
|
||||||
self.model.decoder.map_state(
|
|
||||||
lambda state, dim: state.index_select(dim, select_indices))
|
|
||||||
|
|
||||||
results["scores"] = random_sampler.scores
|
|
||||||
results["predictions"] = random_sampler.predictions
|
|
||||||
results["attention"] = random_sampler.attention
|
|
||||||
return results
|
|
||||||
|
|
||||||
def translate_batch(self, batch, src_vocabs, attn_debug):
|
def translate_batch(self, batch, src_vocabs, attn_debug):
|
||||||
"""Translate a batch of sentences."""
|
"""Translate a batch of sentences."""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.beam_size == 1:
|
if self.beam_size == 1:
|
||||||
return self._translate_random_sampling(
|
decode_strategy = GreedySearch(
|
||||||
batch,
|
pad=self._tgt_pad_idx,
|
||||||
src_vocabs,
|
bos=self._tgt_bos_idx,
|
||||||
self.max_length,
|
eos=self._tgt_eos_idx,
|
||||||
min_length=self.min_length,
|
batch_size=batch.batch_size,
|
||||||
|
min_length=self.min_length, max_length=self.max_length,
|
||||||
|
block_ngram_repeat=self.block_ngram_repeat,
|
||||||
|
exclusion_tokens=self._exclusion_idxs,
|
||||||
|
return_attention=attn_debug or self.replace_unk,
|
||||||
sampling_temp=self.random_sampling_temp,
|
sampling_temp=self.random_sampling_temp,
|
||||||
keep_topk=self.sample_from_topk,
|
keep_topk=self.sample_from_topk)
|
||||||
return_attention=attn_debug or self.replace_unk)
|
|
||||||
else:
|
else:
|
||||||
return self._translate_batch(
|
# TODO: support these blacklisted features
|
||||||
batch,
|
assert not self.dump_beam
|
||||||
src_vocabs,
|
decode_strategy = BeamSearch(
|
||||||
self.max_length,
|
self.beam_size,
|
||||||
min_length=self.min_length,
|
batch_size=batch.batch_size,
|
||||||
ratio=self.ratio,
|
pad=self._tgt_pad_idx,
|
||||||
|
bos=self._tgt_bos_idx,
|
||||||
|
eos=self._tgt_eos_idx,
|
||||||
n_best=self.n_best,
|
n_best=self.n_best,
|
||||||
return_attention=attn_debug or self.replace_unk)
|
global_scorer=self.global_scorer,
|
||||||
|
min_length=self.min_length, max_length=self.max_length,
|
||||||
|
return_attention=attn_debug or self.replace_unk,
|
||||||
|
block_ngram_repeat=self.block_ngram_repeat,
|
||||||
|
exclusion_tokens=self._exclusion_idxs,
|
||||||
|
stepwise_penalty=self.stepwise_penalty,
|
||||||
|
ratio=self.ratio)
|
||||||
|
return self._translate_batch_with_strategy(batch, src_vocabs,
|
||||||
|
decode_strategy)
|
||||||
|
|
||||||
def _run_encoder(self, batch):
|
def _run_encoder(self, batch):
|
||||||
src, src_lengths = batch.src if isinstance(batch.src, tuple) \
|
src, src_lengths = batch.src if isinstance(batch.src, tuple) \
|
||||||
|
@ -577,7 +604,8 @@ class Translator(object):
|
||||||
src_map)
|
src_map)
|
||||||
# here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
|
# here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
|
||||||
if batch_offset is None:
|
if batch_offset is None:
|
||||||
scores = scores.view(batch.batch_size, -1, scores.size(-1))
|
scores = scores.view(-1, batch.batch_size, scores.size(-1))
|
||||||
|
scores = scores.transpose(0, 1).contiguous()
|
||||||
else:
|
else:
|
||||||
scores = scores.view(-1, self.beam_size, scores.size(-1))
|
scores = scores.view(-1, self.beam_size, scores.size(-1))
|
||||||
scores = collapse_copy_scores(
|
scores = collapse_copy_scores(
|
||||||
|
@ -594,21 +622,25 @@ class Translator(object):
|
||||||
# or [ tgt_len, batch_size, vocab ] when full sentence
|
# or [ tgt_len, batch_size, vocab ] when full sentence
|
||||||
return log_probs, attn
|
return log_probs, attn
|
||||||
|
|
||||||
def _translate_batch(
|
def _translate_batch_with_strategy(
|
||||||
self,
|
self,
|
||||||
batch,
|
batch,
|
||||||
src_vocabs,
|
src_vocabs,
|
||||||
max_length,
|
decode_strategy):
|
||||||
min_length=0,
|
"""Translate a batch of sentences step by step using cache.
|
||||||
ratio=0.,
|
|
||||||
n_best=1,
|
|
||||||
return_attention=False):
|
|
||||||
# TODO: support these blacklisted features.
|
|
||||||
assert not self.dump_beam
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: a batch of sentences, yield by data iterator.
|
||||||
|
src_vocabs (list): list of torchtext.data.Vocab if can_copy.
|
||||||
|
decode_strategy (DecodeStrategy): A decode strategy to use for
|
||||||
|
generate translation step by step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
results (dict): The translation results.
|
||||||
|
"""
|
||||||
# (0) Prep the components of the search.
|
# (0) Prep the components of the search.
|
||||||
use_src_map = self.copy_attn
|
use_src_map = self.copy_attn
|
||||||
beam_size = self.beam_size
|
parallel_paths = decode_strategy.parallel_paths # beam_size
|
||||||
batch_size = batch.batch_size
|
batch_size = batch.batch_size
|
||||||
|
|
||||||
# (1) Run the encoder on the src.
|
# (1) Run the encoder on the src.
|
||||||
|
@ -624,42 +656,16 @@ class Translator(object):
|
||||||
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
||||||
enc_states, batch_size, src)}
|
enc_states, batch_size, src)}
|
||||||
|
|
||||||
# (2) Repeat src objects `beam_size` times.
|
# (2) prep decode_strategy. Possibly repeat src objects.
|
||||||
# We use batch_size x beam_size
|
src_map = batch.src_map if use_src_map else None
|
||||||
src_map = (tile(batch.src_map, beam_size, dim=1)
|
fn_map_state, memory_bank, memory_lengths, src_map = \
|
||||||
if use_src_map else None)
|
decode_strategy.initialize(memory_bank, src_lengths, src_map)
|
||||||
self.model.decoder.map_state(
|
if fn_map_state is not None:
|
||||||
lambda state, dim: tile(state, beam_size, dim=dim))
|
self.model.decoder.map_state(fn_map_state)
|
||||||
|
|
||||||
if isinstance(memory_bank, tuple):
|
# (3) Begin decoding step by step:
|
||||||
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
|
for step in range(decode_strategy.max_length):
|
||||||
mb_device = memory_bank[0].device
|
decoder_input = decode_strategy.current_predictions.view(1, -1, 1)
|
||||||
else:
|
|
||||||
memory_bank = tile(memory_bank, beam_size, dim=1)
|
|
||||||
mb_device = memory_bank.device
|
|
||||||
memory_lengths = tile(src_lengths, beam_size)
|
|
||||||
|
|
||||||
# (0) pt 2, prep the beam object
|
|
||||||
beam = BeamSearch(
|
|
||||||
beam_size,
|
|
||||||
n_best=n_best,
|
|
||||||
batch_size=batch_size,
|
|
||||||
global_scorer=self.global_scorer,
|
|
||||||
pad=self._tgt_pad_idx,
|
|
||||||
eos=self._tgt_eos_idx,
|
|
||||||
bos=self._tgt_bos_idx,
|
|
||||||
min_length=min_length,
|
|
||||||
ratio=ratio,
|
|
||||||
max_length=max_length,
|
|
||||||
mb_device=mb_device,
|
|
||||||
return_attention=return_attention,
|
|
||||||
stepwise_penalty=self.stepwise_penalty,
|
|
||||||
block_ngram_repeat=self.block_ngram_repeat,
|
|
||||||
exclusion_tokens=self._exclusion_idxs,
|
|
||||||
memory_lengths=memory_lengths)
|
|
||||||
|
|
||||||
for step in range(max_length):
|
|
||||||
decoder_input = beam.current_predictions.view(1, -1, 1)
|
|
||||||
|
|
||||||
log_probs, attn = self._decode_and_generate(
|
log_probs, attn = self._decode_and_generate(
|
||||||
decoder_input,
|
decoder_input,
|
||||||
|
@ -669,18 +675,18 @@ class Translator(object):
|
||||||
memory_lengths=memory_lengths,
|
memory_lengths=memory_lengths,
|
||||||
src_map=src_map,
|
src_map=src_map,
|
||||||
step=step,
|
step=step,
|
||||||
batch_offset=beam._batch_offset)
|
batch_offset=decode_strategy.batch_offset)
|
||||||
|
|
||||||
beam.advance(log_probs, attn)
|
decode_strategy.advance(log_probs, attn)
|
||||||
any_beam_is_finished = beam.is_finished.any()
|
any_finished = decode_strategy.is_finished.any()
|
||||||
if any_beam_is_finished:
|
if any_finished:
|
||||||
beam.update_finished()
|
decode_strategy.update_finished()
|
||||||
if beam.done:
|
if decode_strategy.done:
|
||||||
break
|
break
|
||||||
|
|
||||||
select_indices = beam.current_origin
|
select_indices = decode_strategy.select_indices
|
||||||
|
|
||||||
if any_beam_is_finished:
|
if any_finished:
|
||||||
# Reorder states.
|
# Reorder states.
|
||||||
if isinstance(memory_bank, tuple):
|
if isinstance(memory_bank, tuple):
|
||||||
memory_bank = tuple(x.index_select(1, select_indices)
|
memory_bank = tuple(x.index_select(1, select_indices)
|
||||||
|
@ -693,107 +699,18 @@ class Translator(object):
|
||||||
if src_map is not None:
|
if src_map is not None:
|
||||||
src_map = src_map.index_select(1, select_indices)
|
src_map = src_map.index_select(1, select_indices)
|
||||||
|
|
||||||
|
if parallel_paths > 1 or any_finished:
|
||||||
self.model.decoder.map_state(
|
self.model.decoder.map_state(
|
||||||
lambda state, dim: state.index_select(dim, select_indices))
|
lambda state, dim: state.index_select(dim, select_indices))
|
||||||
|
|
||||||
results["scores"] = beam.scores
|
results["scores"] = decode_strategy.scores
|
||||||
results["predictions"] = beam.predictions
|
results["predictions"] = decode_strategy.predictions
|
||||||
results["attention"] = beam.attention
|
results["attention"] = decode_strategy.attention
|
||||||
return results
|
if self.report_align:
|
||||||
|
results["alignment"] = self._align_forward(
|
||||||
# This is left in the code for now, but unsued
|
batch, decode_strategy.predictions)
|
||||||
def _translate_batch_deprecated(self, batch, src_vocabs):
|
|
||||||
# (0) Prep each of the components of the search.
|
|
||||||
# And helper method for reducing verbosity.
|
|
||||||
use_src_map = self.copy_attn
|
|
||||||
beam_size = self.beam_size
|
|
||||||
batch_size = batch.batch_size
|
|
||||||
|
|
||||||
beam = [onmt.translate.Beam(
|
|
||||||
beam_size,
|
|
||||||
n_best=self.n_best,
|
|
||||||
cuda=self.cuda,
|
|
||||||
global_scorer=self.global_scorer,
|
|
||||||
pad=self._tgt_pad_idx,
|
|
||||||
eos=self._tgt_eos_idx,
|
|
||||||
bos=self._tgt_bos_idx,
|
|
||||||
min_length=self.min_length,
|
|
||||||
stepwise_penalty=self.stepwise_penalty,
|
|
||||||
block_ngram_repeat=self.block_ngram_repeat,
|
|
||||||
exclusion_tokens=self._exclusion_idxs)
|
|
||||||
for __ in range(batch_size)]
|
|
||||||
|
|
||||||
# (1) Run the encoder on the src.
|
|
||||||
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
|
|
||||||
self.model.decoder.init_state(src, memory_bank, enc_states)
|
|
||||||
|
|
||||||
results = {
|
|
||||||
"predictions": [],
|
|
||||||
"scores": [],
|
|
||||||
"attention": [],
|
|
||||||
"batch": batch,
|
|
||||||
"gold_score": self._gold_score(
|
|
||||||
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
|
|
||||||
enc_states, batch_size, src)}
|
|
||||||
|
|
||||||
# (2) Repeat src objects `beam_size` times.
|
|
||||||
# We use now batch_size x beam_size (same as fast mode)
|
|
||||||
src_map = (tile(batch.src_map, beam_size, dim=1)
|
|
||||||
if use_src_map else None)
|
|
||||||
self.model.decoder.map_state(
|
|
||||||
lambda state, dim: tile(state, beam_size, dim=dim))
|
|
||||||
|
|
||||||
if isinstance(memory_bank, tuple):
|
|
||||||
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
|
|
||||||
else:
|
else:
|
||||||
memory_bank = tile(memory_bank, beam_size, dim=1)
|
results["alignment"] = [[] for _ in range(batch_size)]
|
||||||
memory_lengths = tile(src_lengths, beam_size)
|
|
||||||
|
|
||||||
# (3) run the decoder to generate sentences, using beam search.
|
|
||||||
for i in range(self.max_length):
|
|
||||||
if all((b.done for b in beam)):
|
|
||||||
break
|
|
||||||
|
|
||||||
# (a) Construct batch x beam_size nxt words.
|
|
||||||
# Get all the pending current beam words and arrange for forward.
|
|
||||||
|
|
||||||
inp = torch.stack([b.current_predictions for b in beam])
|
|
||||||
inp = inp.view(1, -1, 1)
|
|
||||||
|
|
||||||
# (b) Decode and forward
|
|
||||||
out, beam_attn = self._decode_and_generate(
|
|
||||||
inp, memory_bank, batch, src_vocabs,
|
|
||||||
memory_lengths=memory_lengths, src_map=src_map, step=i
|
|
||||||
)
|
|
||||||
out = out.view(batch_size, beam_size, -1)
|
|
||||||
beam_attn = beam_attn.view(batch_size, beam_size, -1)
|
|
||||||
|
|
||||||
# (c) Advance each beam.
|
|
||||||
select_indices_array = []
|
|
||||||
# Loop over the batch_size number of beam
|
|
||||||
for j, b in enumerate(beam):
|
|
||||||
if not b.done:
|
|
||||||
b.advance(out[j, :],
|
|
||||||
beam_attn.data[j, :, :memory_lengths[j]])
|
|
||||||
select_indices_array.append(
|
|
||||||
b.current_origin + j * beam_size)
|
|
||||||
select_indices = torch.cat(select_indices_array)
|
|
||||||
|
|
||||||
self.model.decoder.map_state(
|
|
||||||
lambda state, dim: state.index_select(dim, select_indices))
|
|
||||||
|
|
||||||
# (4) Extract sentences from beam.
|
|
||||||
for b in beam:
|
|
||||||
scores, ks = b.sort_finished(minimum=self.n_best)
|
|
||||||
hyps, attn = [], []
|
|
||||||
for times, k in ks[:self.n_best]:
|
|
||||||
hyp, att = b.get_hyp(times, k)
|
|
||||||
hyps.append(hyp)
|
|
||||||
attn.append(att)
|
|
||||||
results["predictions"].append(hyps)
|
|
||||||
results["scores"].append(scores)
|
|
||||||
results["attention"].append(attn)
|
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _score_target(self, batch, memory_bank, src_lengths,
|
def _score_target(self, batch, memory_bank, src_lengths,
|
||||||
|
@ -816,31 +733,9 @@ class Translator(object):
|
||||||
if words_total == 0:
|
if words_total == 0:
|
||||||
msg = "%s No words predicted" % (name,)
|
msg = "%s No words predicted" % (name,)
|
||||||
else:
|
else:
|
||||||
|
avg_score = score_total / words_total
|
||||||
|
ppl = np.exp(-score_total.item() / words_total)
|
||||||
msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
|
msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
|
||||||
name, score_total / words_total,
|
name, avg_score,
|
||||||
name, math.exp(-score_total / words_total)))
|
name, ppl))
|
||||||
return msg
|
|
||||||
|
|
||||||
def _report_bleu(self, tgt_path):
|
|
||||||
import subprocess
|
|
||||||
base_dir = os.path.abspath(__file__ + "/../../..")
|
|
||||||
# Rollback pointer to the beginning.
|
|
||||||
self.out_file.seek(0)
|
|
||||||
print()
|
|
||||||
|
|
||||||
res = subprocess.check_output(
|
|
||||||
"perl %s/tools/multi-bleu.perl %s" % (base_dir, tgt_path),
|
|
||||||
stdin=self.out_file, shell=True
|
|
||||||
).decode("utf-8")
|
|
||||||
|
|
||||||
msg = ">> " + res.strip()
|
|
||||||
return msg
|
|
||||||
|
|
||||||
def _report_rouge(self, tgt_path):
|
|
||||||
import subprocess
|
|
||||||
path = os.path.split(os.path.realpath(__file__))[0]
|
|
||||||
msg = subprocess.check_output(
|
|
||||||
"python %s/tools/test_rouge.py -r %s -c STDIN" % (path, tgt_path),
|
|
||||||
shell=True, stdin=self.out_file
|
|
||||||
).decode("utf-8").strip()
|
|
||||||
return msg
|
return msg
|
||||||
@ -1,5 +1,6 @@
"""Module defining various utilities."""
from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed
+ from onmt.utils.alignment import make_batch_align_matrix
from onmt.utils.report_manager import ReportMgr, build_report_manager
from onmt.utils.statistics import Statistics
from onmt.utils.optimizers import MultipleOptimizer, \
@ -9,4 +10,4 @@ from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts
__all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr",
"build_report_manager", "Statistics",
"MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping",
- "scorers_from_opts"]
+ "scorers_from_opts", "make_batch_align_matrix"]
|
|
|
@ -0,0 +1,139 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from itertools import accumulate
|
||||||
|
|
||||||
|
|
||||||
|
def make_batch_align_matrix(index_tensor, size=None, normalize=False):
|
||||||
|
"""
|
||||||
|
Convert a sparse index_tensor into a batch of alignment matrix,
|
||||||
|
with row normalize to the sum of 1 if set normalize.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id]
|
||||||
|
size (List[int]): Size of the sparse tensor.
|
||||||
|
normalize (bool): if normalize the 2nd dim of resulting tensor.
|
||||||
|
"""
|
||||||
|
n_fill, device = index_tensor.size(0), index_tensor.device
|
||||||
|
value_tensor = torch.ones([n_fill], dtype=torch.float)
|
||||||
|
dense_tensor = torch.sparse_coo_tensor(
|
||||||
|
index_tensor.t(), value_tensor, size=size, device=device).to_dense()
|
||||||
|
if normalize:
|
||||||
|
row_sum = dense_tensor.sum(-1, keepdim=True) # sum by row(tgt)
|
||||||
|
# threshold on 1 to avoid div by 0
|
||||||
|
torch.nn.functional.threshold(row_sum, 1, 1, inplace=True)
|
||||||
|
dense_tensor.div_(row_sum)
|
||||||
|
return dense_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def extract_alignment(align_matrix, tgt_mask, src_lens, n_best):
|
||||||
|
"""
|
||||||
|
Extract a batched align_matrix into its src indice alignment lists,
|
||||||
|
with tgt_mask to filter out invalid tgt position as EOS/PAD.
|
||||||
|
BOS already excluded from tgt_mask in order to match prediction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
align_matrix (Tensor): ``(B, tgt_len, src_len)``,
|
||||||
|
attention head normalized by Softmax(dim=-1)
|
||||||
|
tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD.
|
||||||
|
src_lens (LongTensor): ``(B,)``, containing valid src length
|
||||||
|
n_best (int): a value indicating number of parallel translation.
|
||||||
|
* B: denote flattened batch as B = batch_size * n_best.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best,)``,
|
||||||
|
containing valid alignment matrix (or None if blank prediction)
|
||||||
|
for each translation.
|
||||||
|
"""
|
||||||
|
batch_size_n_best = align_matrix.size(0)
|
||||||
|
assert batch_size_n_best % n_best == 0
|
||||||
|
|
||||||
|
alignments = [[] for _ in range(batch_size_n_best // n_best)]
|
||||||
|
|
||||||
|
# treat alignment matrix one by one as each have different lengths
|
||||||
|
for i, (am_b, tgt_mask_b, src_len) in enumerate(
|
||||||
|
zip(align_matrix, tgt_mask, src_lens)):
|
||||||
|
valid_tgt = ~tgt_mask_b
|
||||||
|
valid_tgt_len = valid_tgt.sum()
|
||||||
|
if valid_tgt_len == 0:
|
||||||
|
# No alignment if not exist valid tgt token
|
||||||
|
valid_alignment = None
|
||||||
|
else:
|
||||||
|
# get valid alignment (sub-matrix from full paded aligment matrix)
|
||||||
|
am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)) \
|
||||||
|
.view(valid_tgt_len, -1)
|
||||||
|
valid_alignment = am_valid_tgt[:, :src_len] # only keep valid src
|
||||||
|
alignments[i // n_best].append(valid_alignment)
|
||||||
|
|
||||||
|
return alignments
|
||||||
|
|
||||||
|
|
||||||
|
def build_align_pharaoh(valid_alignment):
|
||||||
|
"""Convert valid alignment matrix to i-j (from 0) Pharaoh format pairs,
|
||||||
|
or empty list if it's None.
|
||||||
|
"""
|
||||||
|
align_pairs = []
|
||||||
|
if isinstance(valid_alignment, torch.Tensor):
|
||||||
|
tgt_align_src_id = valid_alignment.argmax(dim=-1)
|
||||||
|
|
||||||
|
for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()):
|
||||||
|
align_pairs.append(str(src_id) + "-" + str(tgt_id))
|
||||||
|
align_pairs.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
|
||||||
|
align_pairs.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
|
||||||
|
return align_pairs
|
||||||
|
|
||||||
|
|
||||||
|
def to_word_align(src, tgt, subword_align, mode):
|
||||||
|
"""Convert subword alignment to word alignment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src (string): tokenized sentence in source language.
|
||||||
|
tgt (string): tokenized sentence in target language.
|
||||||
|
subword_align (string): align_pharaoh correspond to src-tgt.
|
||||||
|
mode (string): tokenization mode used by src and tgt,
|
||||||
|
choose from ["joiner", "spacer"].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
word_align (string): converted alignments correspand to
|
||||||
|
detokenized src-tgt.
|
||||||
|
"""
|
||||||
|
src, tgt = src.strip().split(), tgt.strip().split()
|
||||||
|
subword_align = {(int(a), int(b)) for a, b in (x.split("-")
|
||||||
|
for x in subword_align.split())}
|
||||||
|
if mode == 'joiner':
|
||||||
|
src_map = subword_map_by_joiner(src, marker='■')
|
||||||
|
tgt_map = subword_map_by_joiner(tgt, marker='■')
|
||||||
|
elif mode == 'spacer':
|
||||||
|
src_map = subword_map_by_spacer(src, marker='▁')
|
||||||
|
tgt_map = subword_map_by_spacer(tgt, marker='▁')
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid value for argument mode!")
|
||||||
|
word_align = list({"{}-{}".format(src_map[a], tgt_map[b])
|
||||||
|
for a, b in subword_align})
|
||||||
|
word_align.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
|
||||||
|
word_align.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
|
||||||
|
return " ".join(word_align)
|
||||||
|
|
||||||
|
|
||||||
|
def subword_map_by_joiner(subwords, marker='■'):
|
||||||
|
"""Return word id for each subword token (annotate by joiner)."""
|
||||||
|
flags = [0] * len(subwords)
|
||||||
|
for i, tok in enumerate(subwords):
|
||||||
|
if tok.endswith(marker):
|
||||||
|
flags[i] = 1
|
||||||
|
if tok.startswith(marker):
|
||||||
|
assert i >= 1 and flags[i-1] != 1, \
|
||||||
|
"Sentence `{}` not correct!".format(" ".join(subwords))
|
||||||
|
flags[i-1] = 1
|
||||||
|
marker_acc = list(accumulate([0] + flags[:-1]))
|
||||||
|
word_group = [(i - maker_sofar) for i, maker_sofar
|
||||||
|
in enumerate(marker_acc)]
|
||||||
|
return word_group
|
||||||
|
|
||||||
|
|
||||||
|
def subword_map_by_spacer(subwords, marker='▁'):
|
||||||
|
"""Return word id for each subword token (annotate by spacer)."""
|
||||||
|
word_group = list(accumulate([int(marker in x) for x in subwords]))
|
||||||
|
if word_group[0] == 1: # when dummy prefix is set
|
||||||
|
word_group = [item - 1 for item in word_group]
|
||||||
|
return word_group
|
|
@ -2,11 +2,11 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def init_logger(log_file=None, log_file_level=logging.NOTSET):
|
def init_logger(log_file=None, log_file_level=logging.NOTSET, rotate=False):
|
||||||
log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
|
log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
@ -16,6 +16,10 @@ def init_logger(log_file=None, log_file_level=logging.NOTSET):
|
||||||
logger.handlers = [console_handler]
|
logger.handlers = [console_handler]
|
||||||
|
|
||||||
if log_file and log_file != '':
|
if log_file and log_file != '':
|
||||||
|
if rotate:
|
||||||
|
file_handler = RotatingFileHandler(
|
||||||
|
log_file, maxBytes=1000000, backupCount=10)
|
||||||
|
else:
|
||||||
file_handler = logging.FileHandler(log_file)
|
file_handler = logging.FileHandler(log_file)
|
||||||
file_handler.setLevel(log_file_level)
|
file_handler.setLevel(log_file_level)
|
||||||
file_handler.setFormatter(log_format)
|
file_handler.setFormatter(log_format)
|
||||||
|
|
|
@ -57,7 +57,8 @@ def build_loss_compute(model, tgt_field, opt, train=True):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
compute = NMTLossCompute(
|
compute = NMTLossCompute(
|
||||||
criterion, loss_gen, lambda_coverage=opt.lambda_coverage)
|
criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
|
||||||
|
lambda_align=opt.lambda_align)
|
||||||
compute.to(device)
|
compute.to(device)
|
||||||
|
|
||||||
return compute
|
return compute
|
||||||
|
@ -226,9 +227,10 @@ class NMTLossCompute(LossComputeBase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, criterion, generator, normalization="sents",
|
def __init__(self, criterion, generator, normalization="sents",
|
||||||
lambda_coverage=0.0):
|
lambda_coverage=0.0, lambda_align=0.0):
|
||||||
super(NMTLossCompute, self).__init__(criterion, generator)
|
super(NMTLossCompute, self).__init__(criterion, generator)
|
||||||
self.lambda_coverage = lambda_coverage
|
self.lambda_coverage = lambda_coverage
|
||||||
|
self.lambda_align = lambda_align
|
||||||
|
|
||||||
def _make_shard_state(self, batch, output, range_, attns=None):
|
def _make_shard_state(self, batch, output, range_, attns=None):
|
||||||
shard_state = {
|
shard_state = {
|
||||||
|
@ -248,10 +250,33 @@ class NMTLossCompute(LossComputeBase):
|
||||||
"std_attn": attns.get("std"),
|
"std_attn": attns.get("std"),
|
||||||
"coverage_attn": coverage
|
"coverage_attn": coverage
|
||||||
})
|
})
|
||||||
|
if self.lambda_align != 0.0:
|
||||||
|
# attn_align should be in (batch_size, pad_tgt_size, pad_src_size)
|
||||||
|
attn_align = attns.get("align", None)
|
||||||
|
# align_idx should be a Tensor in size([N, 3]), N is total number
|
||||||
|
# of align src-tgt pair in current batch, each as
|
||||||
|
# ['sent_N°_in_batch', 'tgt_id+1', 'src_id'] (check AlignField)
|
||||||
|
align_idx = batch.align
|
||||||
|
assert attns is not None
|
||||||
|
assert attn_align is not None, "lambda_align != 0.0 requires " \
|
||||||
|
"alignement attention head"
|
||||||
|
assert align_idx is not None, "lambda_align != 0.0 requires " \
|
||||||
|
"provide guided alignement"
|
||||||
|
pad_tgt_size, batch_size, _ = batch.tgt.size()
|
||||||
|
pad_src_size = batch.src[0].size(0)
|
||||||
|
align_matrix_size = [batch_size, pad_tgt_size, pad_src_size]
|
||||||
|
ref_align = onmt.utils.make_batch_align_matrix(
|
||||||
|
align_idx, align_matrix_size, normalize=True)
|
||||||
|
# NOTE: tgt-src ref alignement that in range_ of shard
|
||||||
|
# (coherent with batch.tgt)
|
||||||
|
shard_state.update({
|
||||||
|
"align_head": attn_align,
|
||||||
|
"ref_align": ref_align[:, range_[0] + 1: range_[1], :]
|
||||||
|
})
|
||||||
return shard_state
|
return shard_state
|
||||||
|
|
||||||
def _compute_loss(self, batch, output, target, std_attn=None,
|
def _compute_loss(self, batch, output, target, std_attn=None,
|
||||||
coverage_attn=None):
|
coverage_attn=None, align_head=None, ref_align=None):
|
||||||
|
|
||||||
bottled_output = self._bottle(output)
|
bottled_output = self._bottle(output)
|
||||||
|
|
||||||
|
@ -263,15 +288,33 @@ class NMTLossCompute(LossComputeBase):
|
||||||
coverage_loss = self._compute_coverage_loss(
|
coverage_loss = self._compute_coverage_loss(
|
||||||
std_attn=std_attn, coverage_attn=coverage_attn)
|
std_attn=std_attn, coverage_attn=coverage_attn)
|
||||||
loss += coverage_loss
|
loss += coverage_loss
|
||||||
|
if self.lambda_align != 0.0:
|
||||||
|
if align_head.dtype != loss.dtype: # Fix FP16
|
||||||
|
align_head = align_head.to(loss.dtype)
|
||||||
|
if ref_align.dtype != loss.dtype:
|
||||||
|
ref_align = ref_align.to(loss.dtype)
|
||||||
|
align_loss = self._compute_alignement_loss(
|
||||||
|
align_head=align_head, ref_align=ref_align)
|
||||||
|
loss += align_loss
|
||||||
stats = self._stats(loss.clone(), scores, gtruth)
|
stats = self._stats(loss.clone(), scores, gtruth)
|
||||||
|
|
||||||
return loss, stats
|
return loss, stats
|
||||||
|
|
||||||
def _compute_coverage_loss(self, std_attn, coverage_attn):
|
def _compute_coverage_loss(self, std_attn, coverage_attn):
|
||||||
covloss = torch.min(std_attn, coverage_attn).sum(2).view(-1)
|
covloss = torch.min(std_attn, coverage_attn).sum()
|
||||||
covloss *= self.lambda_coverage
|
covloss *= self.lambda_coverage
|
||||||
return covloss
|
return covloss
|
||||||
|
|
||||||
|
def _compute_alignement_loss(self, align_head, ref_align):
|
||||||
|
"""Compute loss between 2 partial alignment matrix."""
|
||||||
|
# align_head contains value in [0, 1) presenting attn prob,
|
||||||
|
# 0 was resulted by the context attention src_pad_mask
|
||||||
|
# So, the correspand position in ref_align should also be 0
|
||||||
|
# Therefore, clip align_head to > 1e-18 should be bias free.
|
||||||
|
align_loss = -align_head.clamp(min=1e-18).log().mul(ref_align).sum()
|
||||||
|
align_loss *= self.lambda_align
|
||||||
|
return align_loss
|
||||||
|
|
||||||
|
|
||||||
def filter_shard_state(state, shard_size=None):
|
def filter_shard_state(state, shard_size=None):
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
|
|
|
@ -3,10 +3,23 @@
|
||||||
import torch
|
import torch
|
||||||
import random
|
import random
|
||||||
import inspect
|
import inspect
|
||||||
from itertools import islice
|
from itertools import islice, repeat
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def split_corpus(path, shard_size):
|
def split_corpus(path, shard_size, default=None):
|
||||||
|
"""yield a `list` containing `shard_size` line of `path`,
|
||||||
|
or repeatly generate `default` if `path` is None.
|
||||||
|
"""
|
||||||
|
if path is not None:
|
||||||
|
return _split_corpus(path, shard_size)
|
||||||
|
else:
|
||||||
|
return repeat(default)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_corpus(path, shard_size):
|
||||||
|
"""Yield a `list` containing `shard_size` line of `path`.
|
||||||
|
"""
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
if shard_size <= 0:
|
if shard_size <= 0:
|
||||||
yield f.readlines()
|
yield f.readlines()
|
||||||
|
@ -124,3 +137,37 @@ def relative_matmul(x, z, transpose):
|
||||||
def fn_args(fun):
|
def fn_args(fun):
|
||||||
"""Returns the list of function arguments name."""
|
"""Returns the list of function arguments name."""
|
||||||
return inspect.getfullargspec(fun).args
|
return inspect.getfullargspec(fun).args
|
||||||
|
|
||||||
|
|
||||||
|
def report_matrix(row_label, column_label, matrix):
|
||||||
|
header_format = "{:>10.10} " + "{:>10.7} " * len(row_label)
|
||||||
|
row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
|
||||||
|
output = header_format.format("", *row_label) + '\n'
|
||||||
|
for word, row in zip(column_label, matrix):
|
||||||
|
max_index = row.index(max(row))
|
||||||
|
row_format = row_format.replace(
|
||||||
|
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
|
||||||
|
row_format = row_format.replace(
|
||||||
|
"{:*>10.7f} ", "{:>10.7f} ", max_index)
|
||||||
|
output += row_format.format(word, *row) + '\n'
|
||||||
|
row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_config(model_config, root):
|
||||||
|
# we need to check the model path + any tokenizer path
|
||||||
|
for model in model_config["models"]:
|
||||||
|
model_path = os.path.join(root, model)
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"{} from model {} does not exist".format(
|
||||||
|
model_path, model_config["id"]))
|
||||||
|
if "tokenizer" in model_config.keys():
|
||||||
|
if "params" in model_config["tokenizer"].keys():
|
||||||
|
for k, v in model_config["tokenizer"]["params"].items():
|
||||||
|
if k.endswith("path"):
|
||||||
|
tok_path = os.path.join(root, v)
|
||||||
|
if not os.path.exists(tok_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"{} from model {} does not exist".format(
|
||||||
|
tok_path, model_config["id"]))
|
||||||
|
|
|
@ -6,6 +6,9 @@ import operator
|
||||||
import functools
|
import functools
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
|
import types
|
||||||
|
import importlib
|
||||||
|
from onmt.utils.misc import fn_args
|
||||||
|
|
||||||
|
|
||||||
def build_torch_optimizer(model, opt):
|
def build_torch_optimizer(model, opt):
|
||||||
|
@ -75,8 +78,8 @@ def build_torch_optimizer(model, opt):
|
||||||
betas=betas,
|
betas=betas,
|
||||||
eps=1e-8)])
|
eps=1e-8)])
|
||||||
elif opt.optim == 'fusedadam':
|
elif opt.optim == 'fusedadam':
|
||||||
import apex
|
# we use here a FusedAdam() copy of an old Apex repo
|
||||||
optimizer = apex.optimizers.FusedAdam(
|
optimizer = FusedAdam(
|
||||||
params,
|
params,
|
||||||
lr=opt.learning_rate,
|
lr=opt.learning_rate,
|
||||||
betas=betas)
|
betas=betas)
|
||||||
|
@ -85,14 +88,23 @@ def build_torch_optimizer(model, opt):
|
||||||
|
|
||||||
if opt.model_dtype == 'fp16':
|
if opt.model_dtype == 'fp16':
|
||||||
import apex
|
import apex
|
||||||
|
if opt.optim != 'fusedadam':
|
||||||
|
# In this case use the new AMP API from apex
|
||||||
loss_scale = "dynamic" if opt.loss_scale == 0 else opt.loss_scale
|
loss_scale = "dynamic" if opt.loss_scale == 0 else opt.loss_scale
|
||||||
model, optimizer = apex.amp.initialize(
|
model, optimizer = apex.amp.initialize(
|
||||||
[model, model.generator],
|
[model, model.generator],
|
||||||
optimizer,
|
optimizer,
|
||||||
opt_level=opt.apex_opt_level,
|
opt_level=opt.apex_opt_level,
|
||||||
loss_scale=loss_scale,
|
loss_scale=loss_scale,
|
||||||
keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None)
|
keep_batchnorm_fp32=None)
|
||||||
|
else:
|
||||||
|
# In this case use the old FusedAdam with FP16_optimizer wrapper
|
||||||
|
static_loss_scale = opt.loss_scale
|
||||||
|
dynamic_loss_scale = opt.loss_scale == 0
|
||||||
|
optimizer = apex.contrib.optimizers.FP16_Optimizer(
|
||||||
|
optimizer,
|
||||||
|
static_loss_scale=static_loss_scale,
|
||||||
|
dynamic_loss_scale=dynamic_loss_scale)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
@ -222,8 +234,7 @@ class Optimizer(object):
|
||||||
self._max_grad_norm = max_grad_norm or 0
|
self._max_grad_norm = max_grad_norm or 0
|
||||||
self._training_step = 1
|
self._training_step = 1
|
||||||
self._decay_step = 1
|
self._decay_step = 1
|
||||||
self._with_fp16_wrapper = (
|
self._fp16 = None
|
||||||
optimizer.__class__.__name__ == "FP16_Optimizer")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_opt(cls, model, opt, checkpoint=None):
|
def from_opt(cls, model, opt, checkpoint=None):
|
||||||
|
@ -273,6 +284,11 @@ class Optimizer(object):
|
||||||
optim_opt.learning_rate,
|
optim_opt.learning_rate,
|
||||||
learning_rate_decay_fn=make_learning_rate_decay_fn(optim_opt),
|
learning_rate_decay_fn=make_learning_rate_decay_fn(optim_opt),
|
||||||
max_grad_norm=optim_opt.max_grad_norm)
|
max_grad_norm=optim_opt.max_grad_norm)
|
||||||
|
if opt.model_dtype == "fp16":
|
||||||
|
if opt.optim == "fusedadam":
|
||||||
|
optimizer._fp16 = "legacy"
|
||||||
|
else:
|
||||||
|
optimizer._fp16 = "amp"
|
||||||
if optim_state_dict:
|
if optim_state_dict:
|
||||||
optimizer.load_state_dict(optim_state_dict)
|
optimizer.load_state_dict(optim_state_dict)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
@ -311,10 +327,15 @@ class Optimizer(object):
|
||||||
def backward(self, loss):
|
def backward(self, loss):
|
||||||
"""Wrapper for backward pass. Some optimizer requires ownership of the
|
"""Wrapper for backward pass. Some optimizer requires ownership of the
|
||||||
backward pass."""
|
backward pass."""
|
||||||
if self._with_fp16_wrapper:
|
if self._fp16 == "amp":
|
||||||
import apex
|
import apex
|
||||||
with apex.amp.scale_loss(loss, self._optimizer) as scaled_loss:
|
with apex.amp.scale_loss(loss, self._optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
|
elif self._fp16 == "legacy":
|
||||||
|
kwargs = {}
|
||||||
|
if "update_master_grads" in fn_args(self._optimizer.backward):
|
||||||
|
kwargs["update_master_grads"] = True
|
||||||
|
self._optimizer.backward(loss, **kwargs)
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
@ -325,17 +346,16 @@ class Optimizer(object):
|
||||||
rate.
|
rate.
|
||||||
"""
|
"""
|
||||||
learning_rate = self.learning_rate()
|
learning_rate = self.learning_rate()
|
||||||
if self._with_fp16_wrapper:
|
if self._fp16 == "legacy":
|
||||||
if hasattr(self._optimizer, "update_master_grads"):
|
if hasattr(self._optimizer, "update_master_grads"):
|
||||||
self._optimizer.update_master_grads()
|
self._optimizer.update_master_grads()
|
||||||
if hasattr(self._optimizer, "clip_master_grads") and \
|
if hasattr(self._optimizer, "clip_master_grads") and \
|
||||||
self._max_grad_norm > 0:
|
self._max_grad_norm > 0:
|
||||||
import apex
|
self._optimizer.clip_master_grads(self._max_grad_norm)
|
||||||
torch.nn.utils.clip_grad_norm_(
|
|
||||||
apex.amp.master_params(self), self._max_grad_norm)
|
|
||||||
for group in self._optimizer.param_groups:
|
for group in self._optimizer.param_groups:
|
||||||
group['lr'] = learning_rate
|
group['lr'] = learning_rate
|
||||||
if not self._with_fp16_wrapper and self._max_grad_norm > 0:
|
if self._fp16 is None and self._max_grad_norm > 0:
|
||||||
clip_grad_norm_(group['params'], self._max_grad_norm)
|
clip_grad_norm_(group['params'], self._max_grad_norm)
|
||||||
self._optimizer.step()
|
self._optimizer.step()
|
||||||
self._decay_step += 1
|
self._decay_step += 1
|
||||||
|
@ -513,3 +533,156 @@ class AdaFactor(torch.optim.Optimizer):
|
||||||
p.data.add_(-group['weight_decay'] * lr_t, p.data)
|
p.data.add_(-group['weight_decay'] * lr_t, p.data)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAdam(torch.optim.Optimizer):
|
||||||
|
|
||||||
|
"""Implements Adam algorithm. Currently GPU-only.
|
||||||
|
Requires Apex to be installed via
|
||||||
|
``python setup.py install --cuda_ext --cpp_ext``.
|
||||||
|
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||||
|
Arguments:
|
||||||
|
params (iterable): iterable of parameters to optimize or dicts defining
|
||||||
|
parameter groups.
|
||||||
|
lr (float, optional): learning rate. (default: 1e-3)
|
||||||
|
betas (Tuple[float, float], optional): coefficients used for computing
|
||||||
|
running averages of gradient and its square.
|
||||||
|
(default: (0.9, 0.999))
|
||||||
|
eps (float, optional): term added to the denominator to improve
|
||||||
|
numerical stability. (default: 1e-8)
|
||||||
|
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||||
|
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||||
|
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||||
|
(default: False) NOT SUPPORTED in FusedAdam!
|
||||||
|
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
|
||||||
|
adds eps to the bias-corrected second moment estimate before
|
||||||
|
evaluating square root instead of adding it to the square root of
|
||||||
|
second moment estimate as in the original paper. (default: False)
|
||||||
|
.. _Adam: A Method for Stochastic Optimization:
|
||||||
|
https://arxiv.org/abs/1412.6980
|
||||||
|
.. _On the Convergence of Adam and Beyond:
|
||||||
|
https://openreview.net/forum?id=ryQu7f-RZ
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, params,
|
||||||
|
lr=1e-3, bias_correction=True,
|
||||||
|
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
|
||||||
|
weight_decay=0., max_grad_norm=0., amsgrad=False):
|
||||||
|
global fused_adam_cuda
|
||||||
|
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
|
||||||
|
|
||||||
|
if amsgrad:
|
||||||
|
raise RuntimeError('AMSGrad variant not supported.')
|
||||||
|
defaults = dict(lr=lr, bias_correction=bias_correction,
|
||||||
|
betas=betas, eps=eps, weight_decay=weight_decay,
|
||||||
|
max_grad_norm=max_grad_norm)
|
||||||
|
super(FusedAdam, self).__init__(params, defaults)
|
||||||
|
self.eps_mode = 0 if eps_inside_sqrt else 1
|
||||||
|
|
||||||
|
def step(self, closure=None, grads=None, output_params=None,
|
||||||
|
scale=1., grad_norms=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
grads (list of tensors, optional): weight gradient to use for the
|
||||||
|
optimizer update. If gradients have type torch.half, parameters
|
||||||
|
are expected to be in type torch.float. (default: None)
|
||||||
|
output params (list of tensors, optional): A reduced precision copy
|
||||||
|
of the updated weights written out in addition to the regular
|
||||||
|
updated weights. Have to be of same type as gradients.
|
||||||
|
(default: None)
|
||||||
|
scale (float, optional): factor to divide gradient tensor values
|
||||||
|
by before applying to weights. (default: 1)
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
if grads is None:
|
||||||
|
grads_group = [None]*len(self.param_groups)
|
||||||
|
# backward compatibility
|
||||||
|
# assuming a list/generator of parameter means single group
|
||||||
|
elif isinstance(grads, types.GeneratorType):
|
||||||
|
grads_group = [grads]
|
||||||
|
elif type(grads[0]) != list:
|
||||||
|
grads_group = [grads]
|
||||||
|
else:
|
||||||
|
grads_group = grads
|
||||||
|
|
||||||
|
if output_params is None:
|
||||||
|
output_params_group = [None]*len(self.param_groups)
|
||||||
|
elif isinstance(output_params, types.GeneratorType):
|
||||||
|
output_params_group = [output_params]
|
||||||
|
elif type(output_params[0]) != list:
|
||||||
|
output_params_group = [output_params]
|
||||||
|
else:
|
||||||
|
output_params_group = output_params
|
||||||
|
|
||||||
|
if grad_norms is None:
|
||||||
|
grad_norms = [None]*len(self.param_groups)
|
||||||
|
|
||||||
|
for group, grads_this_group, output_params_this_group, \
|
||||||
|
grad_norm in zip(self.param_groups, grads_group,
|
||||||
|
output_params_group, grad_norms):
|
||||||
|
if grads_this_group is None:
|
||||||
|
grads_this_group = [None]*len(group['params'])
|
||||||
|
if output_params_this_group is None:
|
||||||
|
output_params_this_group = [None]*len(group['params'])
|
||||||
|
|
||||||
|
# compute combined scale factor for this group
|
||||||
|
combined_scale = scale
|
||||||
|
if group['max_grad_norm'] > 0:
|
||||||
|
# norm is in fact norm*scale
|
||||||
|
clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
|
||||||
|
if clip > 1:
|
||||||
|
combined_scale = clip * scale
|
||||||
|
|
||||||
|
bias_correction = 1 if group['bias_correction'] else 0
|
||||||
|
|
||||||
|
for p, grad, output_param in zip(group['params'],
|
||||||
|
grads_this_group,
|
||||||
|
output_params_this_group):
|
||||||
|
# note: p.grad should not ever be set for correct operation of
|
||||||
|
# mixed precision optimizer that sometimes sends None gradients
|
||||||
|
if p.grad is None and grad is None:
|
||||||
|
continue
|
||||||
|
if grad is None:
|
||||||
|
grad = p.grad.data
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError('FusedAdam does not support sparse \
|
||||||
|
gradients, please consider \
|
||||||
|
SparseAdam instead')
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
# State initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
state['step'] = 0
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state['exp_avg'] = torch.zeros_like(p.data)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||||
|
|
||||||
|
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||||
|
beta1, beta2 = group['betas']
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
|
||||||
|
out_p = torch.tensor([], dtype=torch.float) if output_param \
|
||||||
|
is None else output_param
|
||||||
|
fused_adam_cuda.adam(p.data,
|
||||||
|
out_p,
|
||||||
|
exp_avg,
|
||||||
|
exp_avg_sq,
|
||||||
|
grad,
|
||||||
|
group['lr'],
|
||||||
|
beta1,
|
||||||
|
beta2,
|
||||||
|
group['eps'],
|
||||||
|
combined_scale,
|
||||||
|
state['step'],
|
||||||
|
self.eps_mode,
|
||||||
|
bias_correction,
|
||||||
|
group['weight_decay'])
|
||||||
|
return loss
|
||||||
|
|
|
@ -46,6 +46,11 @@ class ArgumentParser(cfargparse.ArgumentParser):
|
||||||
if model_opt.copy_attn_type is None:
|
if model_opt.copy_attn_type is None:
|
||||||
model_opt.copy_attn_type = model_opt.global_attention
|
model_opt.copy_attn_type = model_opt.global_attention
|
||||||
|
|
||||||
|
if model_opt.alignment_layer is None:
|
||||||
|
model_opt.alignment_layer = -2
|
||||||
|
model_opt.lambda_align = 0.0
|
||||||
|
model_opt.full_context_alignment = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_model_opts(cls, model_opt):
|
def validate_model_opts(cls, model_opt):
|
||||||
assert model_opt.model_type in ["text", "img", "audio", "vec"], \
|
assert model_opt.model_type in ["text", "img", "audio", "vec"], \
|
||||||
|
@ -63,6 +68,17 @@ class ArgumentParser(cfargparse.ArgumentParser):
|
||||||
if model_opt.model_type != "text":
|
if model_opt.model_type != "text":
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"--share_embeddings requires --model_type text.")
|
"--share_embeddings requires --model_type text.")
|
||||||
|
if model_opt.lambda_align > 0.0:
|
||||||
|
assert model_opt.decoder_type == 'transformer', \
|
||||||
|
"Only transformer is supported to joint learn alignment."
|
||||||
|
assert model_opt.alignment_layer < model_opt.dec_layers and \
|
||||||
|
model_opt.alignment_layer >= -model_opt.dec_layers, \
|
||||||
|
"N° alignment_layer should be smaller than number of layers."
|
||||||
|
logger.info("Joint learn alignment at layer [{}] "
|
||||||
|
"with {} heads in full_context '{}'.".format(
|
||||||
|
model_opt.alignment_layer,
|
||||||
|
model_opt.alignment_heads,
|
||||||
|
model_opt.full_context_alignment))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ckpt_model_opts(cls, ckpt_opt):
|
def ckpt_model_opts(cls, ckpt_opt):
|
||||||
|
@ -85,8 +101,7 @@ class ArgumentParser(cfargparse.ArgumentParser):
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"gpuid is deprecated see world_size and gpu_ranks")
|
"gpuid is deprecated see world_size and gpu_ranks")
|
||||||
if torch.cuda.is_available() and not opt.gpu_ranks:
|
if torch.cuda.is_available() and not opt.gpu_ranks:
|
||||||
logger.info("WARNING: You have a CUDA device, \
|
logger.warn("You have a CUDA device, should run with -gpu_ranks")
|
||||||
should run with -gpu_ranks")
|
|
||||||
if opt.world_size < len(opt.gpu_ranks):
|
if opt.world_size < len(opt.gpu_ranks):
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"parameter counts of -gpu_ranks must be less or equal "
|
"parameter counts of -gpu_ranks must be less or equal "
|
||||||
|
@ -128,6 +143,18 @@ class ArgumentParser(cfargparse.ArgumentParser):
|
||||||
for file in opt.train_src + opt.train_tgt:
|
for file in opt.train_src + opt.train_tgt:
|
||||||
assert os.path.isfile(file), "Please check path of %s" % file
|
assert os.path.isfile(file), "Please check path of %s" % file
|
||||||
|
|
||||||
|
if len(opt.train_align) == 1 and opt.train_align[0] is None:
|
||||||
|
opt.train_align = [None] * len(opt.train_src)
|
||||||
|
else:
|
||||||
|
assert len(opt.train_align) == len(opt.train_src), \
|
||||||
|
"Please provide same number of word alignment train \
|
||||||
|
files as src/tgt!"
|
||||||
|
for file in opt.train_align:
|
||||||
|
assert os.path.isfile(file), "Please check path of %s" % file
|
||||||
|
|
||||||
|
assert not opt.valid_align or os.path.isfile(opt.valid_align), \
|
||||||
|
"Please check path of your valid alignment file!"
|
||||||
|
|
||||||
assert not opt.valid_src or os.path.isfile(opt.valid_src), \
|
assert not opt.valid_src or os.path.isfile(opt.valid_src), \
|
||||||
"Please check path of your valid src file!"
|
"Please check path of your valid src file!"
|
||||||
assert not opt.valid_tgt or os.path.isfile(opt.valid_tgt), \
|
assert not opt.valid_tgt or os.path.isfile(opt.valid_tgt), \
|
||||||
|
|
|
@ -8,16 +8,15 @@ import onmt
|
||||||
from onmt.utils.logging import logger
|
from onmt.utils.logging import logger
|
||||||
|
|
||||||
|
|
||||||
def build_report_manager(opt):
|
def build_report_manager(opt, gpu_rank):
|
||||||
if opt.tensorboard:
|
if opt.tensorboard and gpu_rank == 0:
|
||||||
from tensorboardX import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
tensorboard_log_dir = opt.tensorboard_log_dir
|
tensorboard_log_dir = opt.tensorboard_log_dir
|
||||||
|
|
||||||
if not opt.train_from:
|
if not opt.train_from:
|
||||||
tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S")
|
tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S")
|
||||||
|
|
||||||
writer = SummaryWriter(tensorboard_log_dir,
|
writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
|
||||||
comment="Unmt")
|
|
||||||
else:
|
else:
|
||||||
writer = None
|
writer = None
|
||||||
|
|
||||||
|
@ -42,7 +41,6 @@ class ReportMgrBase(object):
|
||||||
means that you will need to set it later or use `start()`
|
means that you will need to set it later or use `start()`
|
||||||
"""
|
"""
|
||||||
self.report_every = report_every
|
self.report_every = report_every
|
||||||
self.progress_step = 0
|
|
||||||
self.start_time = start_time
|
self.start_time = start_time
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
|
@ -75,7 +73,6 @@ class ReportMgrBase(object):
|
||||||
onmt.utils.Statistics.all_gather_stats(report_stats)
|
onmt.utils.Statistics.all_gather_stats(report_stats)
|
||||||
self._report_training(
|
self._report_training(
|
||||||
step, num_steps, learning_rate, report_stats)
|
step, num_steps, learning_rate, report_stats)
|
||||||
self.progress_step += 1
|
|
||||||
return onmt.utils.Statistics()
|
return onmt.utils.Statistics()
|
||||||
else:
|
else:
|
||||||
return report_stats
|
return report_stats
|
||||||
|
@ -127,11 +124,10 @@ class ReportMgr(ReportMgrBase):
|
||||||
report_stats.output(step, num_steps,
|
report_stats.output(step, num_steps,
|
||||||
learning_rate, self.start_time)
|
learning_rate, self.start_time)
|
||||||
|
|
||||||
# Log the progress using the number of batches on the x-axis.
|
|
||||||
self.maybe_log_tensorboard(report_stats,
|
self.maybe_log_tensorboard(report_stats,
|
||||||
"progress",
|
"progress",
|
||||||
learning_rate,
|
learning_rate,
|
||||||
self.progress_step)
|
step)
|
||||||
report_stats = onmt.utils.Statistics()
|
report_stats = onmt.utils.Statistics()
|
||||||
|
|
||||||
return report_stats
|
return report_stats
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
import onmt.opts as opts
|
||||||
|
from onmt.utils.parse import ArgumentParser
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
opts.config_opts(parser)
|
||||||
|
opts.model_opts(parser)
|
||||||
|
opts.global_opts(parser)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
with open(os.path.join(dir_path, 'opt_data'), 'wb') as f:
|
||||||
|
pickle.dump(opt, f)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,220 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
# -*- coding: utf-8 -*-
|
from onmt.bin.preprocess import main
|
||||||
"""
|
|
||||||
Pre-process Data / features files and build vocabulary
|
|
||||||
"""
|
|
||||||
import codecs
|
|
||||||
import glob
|
|
||||||
import sys
|
|
||||||
import gc
|
|
||||||
import torch
|
|
||||||
from functools import partial
|
|
||||||
from collections import Counter, defaultdict
|
|
||||||
|
|
||||||
from onmt.utils.logging import init_logger, logger
|
|
||||||
from onmt.utils.misc import split_corpus
|
|
||||||
import onmt.inputters as inputters
|
|
||||||
import onmt.opts as opts
|
|
||||||
from onmt.utils.parse import ArgumentParser
|
|
||||||
from onmt.inputters.inputter import _build_fields_vocab,\
|
|
||||||
_load_vocab
|
|
||||||
|
|
||||||
|
|
||||||
def check_existing_pt_files(opt):
|
|
||||||
""" Check if there are existing .pt files to avoid overwriting them """
|
|
||||||
pattern = opt.save_data + '.{}*.pt'
|
|
||||||
for t in ['train', 'valid']:
|
|
||||||
path = pattern.format(t)
|
|
||||||
if glob.glob(path):
|
|
||||||
sys.stderr.write("Please backup existing pt files: %s, "
|
|
||||||
"to avoid overwriting them!\n" % path)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
|
|
||||||
assert corpus_type in ['train', 'valid']
|
|
||||||
|
|
||||||
if corpus_type == 'train':
|
|
||||||
counters = defaultdict(Counter)
|
|
||||||
srcs = opt.train_src
|
|
||||||
tgts = opt.train_tgt
|
|
||||||
ids = opt.train_ids
|
|
||||||
else:
|
|
||||||
srcs = [opt.valid_src]
|
|
||||||
tgts = [opt.valid_tgt]
|
|
||||||
ids = [None]
|
|
||||||
|
|
||||||
for src, tgt, maybe_id in zip(srcs, tgts, ids):
|
|
||||||
logger.info("Reading source and target files: %s %s." % (src, tgt))
|
|
||||||
|
|
||||||
src_shards = split_corpus(src, opt.shard_size)
|
|
||||||
tgt_shards = split_corpus(tgt, opt.shard_size)
|
|
||||||
shard_pairs = zip(src_shards, tgt_shards)
|
|
||||||
dataset_paths = []
|
|
||||||
if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
|
|
||||||
filter_pred = partial(
|
|
||||||
inputters.filter_example, use_src_len=opt.data_type == "text",
|
|
||||||
max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
|
|
||||||
else:
|
|
||||||
filter_pred = None
|
|
||||||
|
|
||||||
if corpus_type == "train":
|
|
||||||
existing_fields = None
|
|
||||||
if opt.src_vocab != "":
|
|
||||||
try:
|
|
||||||
logger.info("Using existing vocabulary...")
|
|
||||||
existing_fields = torch.load(opt.src_vocab)
|
|
||||||
except torch.serialization.pickle.UnpicklingError:
|
|
||||||
logger.info("Building vocab from text file...")
|
|
||||||
src_vocab, src_vocab_size = _load_vocab(
|
|
||||||
opt.src_vocab, "src", counters,
|
|
||||||
opt.src_words_min_frequency)
|
|
||||||
else:
|
|
||||||
src_vocab = None
|
|
||||||
|
|
||||||
if opt.tgt_vocab != "":
|
|
||||||
tgt_vocab, tgt_vocab_size = _load_vocab(
|
|
||||||
opt.tgt_vocab, "tgt", counters,
|
|
||||||
opt.tgt_words_min_frequency)
|
|
||||||
else:
|
|
||||||
tgt_vocab = None
|
|
||||||
|
|
||||||
for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
|
|
||||||
assert len(src_shard) == len(tgt_shard)
|
|
||||||
logger.info("Building shard %d." % i)
|
|
||||||
dataset = inputters.Dataset(
|
|
||||||
fields,
|
|
||||||
readers=([src_reader, tgt_reader]
|
|
||||||
if tgt_reader else [src_reader]),
|
|
||||||
data=([("src", src_shard), ("tgt", tgt_shard)]
|
|
||||||
if tgt_reader else [("src", src_shard)]),
|
|
||||||
dirs=([opt.src_dir, None]
|
|
||||||
if tgt_reader else [opt.src_dir]),
|
|
||||||
sort_key=inputters.str2sortkey[opt.data_type],
|
|
||||||
filter_pred=filter_pred
|
|
||||||
)
|
|
||||||
if corpus_type == "train" and existing_fields is None:
|
|
||||||
for ex in dataset.examples:
|
|
||||||
for name, field in fields.items():
|
|
||||||
try:
|
|
||||||
f_iter = iter(field)
|
|
||||||
except TypeError:
|
|
||||||
f_iter = [(name, field)]
|
|
||||||
all_data = [getattr(ex, name, None)]
|
|
||||||
else:
|
|
||||||
all_data = getattr(ex, name)
|
|
||||||
for (sub_n, sub_f), fd in zip(
|
|
||||||
f_iter, all_data):
|
|
||||||
has_vocab = (sub_n == 'src' and
|
|
||||||
src_vocab is not None) or \
|
|
||||||
(sub_n == 'tgt' and
|
|
||||||
tgt_vocab is not None)
|
|
||||||
if (hasattr(sub_f, 'sequential')
|
|
||||||
and sub_f.sequential and not has_vocab):
|
|
||||||
val = fd
|
|
||||||
counters[sub_n].update(val)
|
|
||||||
if maybe_id:
|
|
||||||
shard_base = corpus_type + "_" + maybe_id
|
|
||||||
else:
|
|
||||||
shard_base = corpus_type
|
|
||||||
data_path = "{:s}.{:s}.{:d}.pt".\
|
|
||||||
format(opt.save_data, shard_base, i)
|
|
||||||
dataset_paths.append(data_path)
|
|
||||||
|
|
||||||
logger.info(" * saving %sth %s data shard to %s."
|
|
||||||
% (i, shard_base, data_path))
|
|
||||||
|
|
||||||
dataset.save(data_path)
|
|
||||||
|
|
||||||
del dataset.examples
|
|
||||||
gc.collect()
|
|
||||||
del dataset
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
if corpus_type == "train":
|
|
||||||
vocab_path = opt.save_data + '.vocab.pt'
|
|
||||||
if existing_fields is None:
|
|
||||||
fields = _build_fields_vocab(
|
|
||||||
fields, counters, opt.data_type,
|
|
||||||
opt.share_vocab, opt.vocab_size_multiple,
|
|
||||||
opt.src_vocab_size, opt.src_words_min_frequency,
|
|
||||||
opt.tgt_vocab_size, opt.tgt_words_min_frequency)
|
|
||||||
else:
|
|
||||||
fields = existing_fields
|
|
||||||
torch.save(fields, vocab_path)
|
|
||||||
|
|
||||||
|
|
||||||
def build_save_vocab(train_dataset, fields, opt):
|
|
||||||
fields = inputters.build_vocab(
|
|
||||||
train_dataset, fields, opt.data_type, opt.share_vocab,
|
|
||||||
opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency,
|
|
||||||
opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency,
|
|
||||||
vocab_size_multiple=opt.vocab_size_multiple
|
|
||||||
)
|
|
||||||
vocab_path = opt.save_data + '.vocab.pt'
|
|
||||||
torch.save(fields, vocab_path)
|
|
||||||
|
|
||||||
|
|
||||||
def count_features(path):
|
|
||||||
"""
|
|
||||||
path: location of a corpus file with whitespace-delimited tokens and
|
|
||||||
│-delimited features within the token
|
|
||||||
returns: the number of features in the dataset
|
|
||||||
"""
|
|
||||||
with codecs.open(path, "r", "utf-8") as f:
|
|
||||||
first_tok = f.readline().split(None, 1)[0]
|
|
||||||
return len(first_tok.split(u"│")) - 1
|
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
|
||||||
ArgumentParser.validate_preprocess_args(opt)
|
|
||||||
torch.manual_seed(opt.seed)
|
|
||||||
if not(opt.overwrite):
|
|
||||||
check_existing_pt_files(opt)
|
|
||||||
|
|
||||||
init_logger(opt.log_file)
|
|
||||||
logger.info("Extracting features...")
|
|
||||||
|
|
||||||
src_nfeats = 0
|
|
||||||
tgt_nfeats = 0
|
|
||||||
for src, tgt in zip(opt.train_src, opt.train_tgt):
|
|
||||||
src_nfeats += count_features(src) if opt.data_type == 'text' \
|
|
||||||
else 0
|
|
||||||
tgt_nfeats += count_features(tgt) # tgt always text so far
|
|
||||||
logger.info(" * number of source features: %d." % src_nfeats)
|
|
||||||
logger.info(" * number of target features: %d." % tgt_nfeats)
|
|
||||||
|
|
||||||
logger.info("Building `Fields` object...")
|
|
||||||
fields = inputters.get_fields(
|
|
||||||
opt.data_type,
|
|
||||||
src_nfeats,
|
|
||||||
tgt_nfeats,
|
|
||||||
dynamic_dict=opt.dynamic_dict,
|
|
||||||
src_truncate=opt.src_seq_length_trunc,
|
|
||||||
tgt_truncate=opt.tgt_seq_length_trunc)
|
|
||||||
|
|
||||||
src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
|
|
||||||
tgt_reader = inputters.str2reader["text"].from_opt(opt)
|
|
||||||
|
|
||||||
logger.info("Building & saving training data...")
|
|
||||||
build_save_dataset(
|
|
||||||
'train', fields, src_reader, tgt_reader, opt)
|
|
||||||
|
|
||||||
if opt.valid_src and opt.valid_tgt:
|
|
||||||
logger.info("Building & saving validation data...")
|
|
||||||
build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_parser():
|
|
||||||
parser = ArgumentParser(description='preprocess.py')
|
|
||||||
|
|
||||||
opts.config_opts(parser)
|
|
||||||
opts.preprocess_opts(parser)
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = _get_parser()
|
main()
|
||||||
|
|
||||||
opt = parser.parse_args()
|
|
||||||
main(opt)
|
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
cffi
|
cffi
|
||||||
torchvision==0.2.1
|
torchvision
|
||||||
joblib
|
joblib
|
||||||
librosa
|
librosa
|
||||||
Pillow
|
Pillow
|
||||||
git+git://github.com/pytorch/audio.git@d92de5b97fc6204db4b1e3ed20c03ac06f5d53f0
|
git+git://github.com/pytorch/audio.git@d92de5b97fc6204db4b1e3ed20c03ac06f5d53f0
|
||||||
pyrouge
|
pyrouge
|
||||||
pyonmttok
|
|
||||||
opencv-python
|
opencv-python
|
||||||
git+https://github.com/NVIDIA/apex
|
git+https://github.com/NVIDIA/apex
|
||||||
flask
|
pretrainedmodels
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
six
|
|
||||||
tqdm==4.30.*
|
|
||||||
torch>=1.1
|
|
||||||
git+https://github.com/pytorch/text.git@master#wheel=torchtext
|
|
||||||
future
|
|
||||||
configargparse
|
|
|
@ -1,129 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import configargparse
|
from onmt.bin.server import main
|
||||||
|
|
||||||
from flask import Flask, jsonify, request
|
|
||||||
from onmt.translate import TranslationServer, ServerModelError
|
|
||||||
|
|
||||||
STATUS_OK = "ok"
|
|
||||||
STATUS_ERROR = "error"
|
|
||||||
|
|
||||||
|
|
||||||
def start(config_file,
|
if __name__ == "__main__":
|
||||||
url_root="./translator",
|
main()
|
||||||
host="0.0.0.0",
|
|
||||||
port=5000,
|
|
||||||
debug=True):
|
|
||||||
def prefix_route(route_function, prefix='', mask='{0}{1}'):
|
|
||||||
def newroute(route, *args, **kwargs):
|
|
||||||
return route_function(mask.format(prefix, route), *args, **kwargs)
|
|
||||||
return newroute
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
app.route = prefix_route(app.route, url_root)
|
|
||||||
translation_server = TranslationServer()
|
|
||||||
translation_server.start(config_file)
|
|
||||||
|
|
||||||
@app.route('/models', methods=['GET'])
|
|
||||||
def get_models():
|
|
||||||
out = translation_server.list_models()
|
|
||||||
return jsonify(out)
|
|
||||||
|
|
||||||
@app.route('/health', methods=['GET'])
|
|
||||||
def health():
|
|
||||||
out = {}
|
|
||||||
out['status'] = STATUS_OK
|
|
||||||
return jsonify(out)
|
|
||||||
|
|
||||||
@app.route('/clone_model/<int:model_id>', methods=['POST'])
|
|
||||||
def clone_model(model_id):
|
|
||||||
out = {}
|
|
||||||
data = request.get_json(force=True)
|
|
||||||
timeout = -1
|
|
||||||
if 'timeout' in data:
|
|
||||||
timeout = data['timeout']
|
|
||||||
del data['timeout']
|
|
||||||
|
|
||||||
opt = data.get('opt', None)
|
|
||||||
try:
|
|
||||||
model_id, load_time = translation_server.clone_model(
|
|
||||||
model_id, opt, timeout)
|
|
||||||
except ServerModelError as e:
|
|
||||||
out['status'] = STATUS_ERROR
|
|
||||||
out['error'] = str(e)
|
|
||||||
else:
|
|
||||||
out['status'] = STATUS_OK
|
|
||||||
out['model_id'] = model_id
|
|
||||||
out['load_time'] = load_time
|
|
||||||
|
|
||||||
return jsonify(out)
|
|
||||||
|
|
||||||
@app.route('/unload_model/<int:model_id>', methods=['GET'])
|
|
||||||
def unload_model(model_id):
|
|
||||||
out = {"model_id": model_id}
|
|
||||||
|
|
||||||
try:
|
|
||||||
translation_server.unload_model(model_id)
|
|
||||||
out['status'] = STATUS_OK
|
|
||||||
except Exception as e:
|
|
||||||
out['status'] = STATUS_ERROR
|
|
||||||
out['error'] = str(e)
|
|
||||||
|
|
||||||
return jsonify(out)
|
|
||||||
|
|
||||||
@app.route('/translate', methods=['POST'])
|
|
||||||
def translate():
|
|
||||||
inputs = request.get_json(force=True)
|
|
||||||
out = {}
|
|
||||||
try:
|
|
||||||
translation, scores, n_best, times = translation_server.run(inputs)
|
|
||||||
assert len(translation) == len(inputs)
|
|
||||||
assert len(scores) == len(inputs)
|
|
||||||
|
|
||||||
out = [[{"src": inputs[i]['src'], "tgt": translation[i],
|
|
||||||
"n_best": n_best,
|
|
||||||
"pred_score": scores[i]}
|
|
||||||
for i in range(len(translation))]]
|
|
||||||
except ServerModelError as e:
|
|
||||||
out['error'] = str(e)
|
|
||||||
out['status'] = STATUS_ERROR
|
|
||||||
|
|
||||||
return jsonify(out)
|
|
||||||
|
|
||||||
@app.route('/to_cpu/<int:model_id>', methods=['GET'])
|
|
||||||
def to_cpu(model_id):
|
|
||||||
out = {'model_id': model_id}
|
|
||||||
translation_server.models[model_id].to_cpu()
|
|
||||||
|
|
||||||
out['status'] = STATUS_OK
|
|
||||||
return jsonify(out)
|
|
||||||
|
|
||||||
@app.route('/to_gpu/<int:model_id>', methods=['GET'])
|
|
||||||
def to_gpu(model_id):
|
|
||||||
out = {'model_id': model_id}
|
|
||||||
translation_server.models[model_id].to_gpu()
|
|
||||||
|
|
||||||
out['status'] = STATUS_OK
|
|
||||||
return jsonify(out)
|
|
||||||
|
|
||||||
app.run(debug=debug, host=host, port=port, use_reloader=False,
|
|
||||||
threaded=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_parser():
|
|
||||||
parser = configargparse.ArgumentParser(
|
|
||||||
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
|
||||||
description="OpenNMT-py REST Server")
|
|
||||||
parser.add_argument("--ip", type=str, default="0.0.0.0")
|
|
||||||
parser.add_argument("--port", type=int, default="5000")
|
|
||||||
parser.add_argument("--url_root", type=str, default="/translator")
|
|
||||||
parser.add_argument("--debug", "-d", action="store_true")
|
|
||||||
parser.add_argument("--config", "-c", type=str,
|
|
||||||
default="./available_models/conf.json")
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = _get_parser()
|
|
||||||
args = parser.parse_args()
|
|
||||||
start(args.config, url_root=args.url_root, host=args.ip, port=args.port,
|
|
||||||
debug=args.debug)
|
|
||||||
|
|
|
@ -1,11 +1,45 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
from os import path
|
||||||
|
|
||||||
from setuptools import setup
|
this_directory = path.abspath(path.dirname(__file__))
|
||||||
|
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
|
||||||
|
long_description = f.read()
|
||||||
|
|
||||||
setup(name='OpenNMT-py',
|
setup(
|
||||||
|
name='OpenNMT-py',
|
||||||
description='A python implementation of OpenNMT',
|
description='A python implementation of OpenNMT',
|
||||||
version='0.9.1',
|
long_description=long_description,
|
||||||
|
long_description_content_type='text/markdown',
|
||||||
packages=['onmt', 'onmt.encoders', 'onmt.modules', 'onmt.tests',
|
version='1.1.1',
|
||||||
'onmt.translate', 'onmt.decoders', 'onmt.inputters',
|
packages=find_packages(),
|
||||||
'onmt.models', 'onmt.utils'])
|
project_urls={
|
||||||
|
"Documentation": "http://opennmt.net/OpenNMT-py/",
|
||||||
|
"Forum": "http://forum.opennmt.net/",
|
||||||
|
"Gitter": "https://gitter.im/OpenNMT/OpenNMT-py",
|
||||||
|
"Source": "https://github.com/OpenNMT/OpenNMT-py/"
|
||||||
|
},
|
||||||
|
install_requires=[
|
||||||
|
"six",
|
||||||
|
"tqdm~=4.30.0",
|
||||||
|
"torch>=1.4.0",
|
||||||
|
"torchtext==0.4.0",
|
||||||
|
"future",
|
||||||
|
"configargparse",
|
||||||
|
"tensorboard>=1.14",
|
||||||
|
"flask",
|
||||||
|
"waitress",
|
||||||
|
"pyonmttok==1.*;platform_system=='Linux'",
|
||||||
|
"pyyaml",
|
||||||
|
],
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": [
|
||||||
|
"onmt_server=onmt.bin.server:main",
|
||||||
|
"onmt_train=onmt.bin.train:main",
|
||||||
|
"onmt_translate=onmt.bin.translate:main",
|
||||||
|
"onmt_preprocess=onmt.bin.preprocess:main",
|
||||||
|
"onmt_release_model=onmt.bin.release_model:main",
|
||||||
|
"onmt_average_models=onmt.bin.average_models:main"
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
|
@ -1,45 +1,5 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import argparse
|
from onmt.bin.average_models import main
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def average_models(model_files):
|
|
||||||
vocab = None
|
|
||||||
opt = None
|
|
||||||
avg_model = None
|
|
||||||
avg_generator = None
|
|
||||||
|
|
||||||
for i, model_file in enumerate(model_files):
|
|
||||||
m = torch.load(model_file, map_location='cpu')
|
|
||||||
model_weights = m['model']
|
|
||||||
generator_weights = m['generator']
|
|
||||||
|
|
||||||
if i == 0:
|
|
||||||
vocab, opt = m['vocab'], m['opt']
|
|
||||||
avg_model = model_weights
|
|
||||||
avg_generator = generator_weights
|
|
||||||
else:
|
|
||||||
for (k, v) in avg_model.items():
|
|
||||||
avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
|
|
||||||
|
|
||||||
for (k, v) in avg_generator.items():
|
|
||||||
avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
|
|
||||||
|
|
||||||
final = {"vocab": vocab, "opt": opt, "optim": None,
|
|
||||||
"generator": avg_generator, "model": avg_model}
|
|
||||||
return final
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="")
|
|
||||||
parser.add_argument("-models", "-m", nargs="+", required=True,
|
|
||||||
help="List of models")
|
|
||||||
parser.add_argument("-output", "-o", required=True,
|
|
||||||
help="Output file")
|
|
||||||
opt = parser.parse_args()
|
|
||||||
|
|
||||||
final = average_models(opt.models)
|
|
||||||
torch.save(final, opt.output)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def read_files_batch(file_list):
|
def read_files_batch(file_list):
|
||||||
|
@ -10,7 +11,6 @@ def read_files_batch(file_list):
|
||||||
fd_list = [] # File descriptor list
|
fd_list = [] # File descriptor list
|
||||||
|
|
||||||
exit = False # Flag used for quitting the program in case of error
|
exit = False # Flag used for quitting the program in case of error
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for filename in file_list:
|
for filename in file_list:
|
||||||
fd_list.append(open(filename))
|
fd_list.append(open(filename))
|
||||||
|
@ -51,7 +51,8 @@ def main():
|
||||||
corresponding to the argument 'side'.""")
|
corresponding to the argument 'side'.""")
|
||||||
parser.add_argument("-file", type=str, nargs="+", required=True)
|
parser.add_argument("-file", type=str, nargs="+", required=True)
|
||||||
parser.add_argument("-out_file", type=str, required=True)
|
parser.add_argument("-out_file", type=str, required=True)
|
||||||
parser.add_argument("-side", type=str)
|
parser.add_argument("-side", choices=['src', 'tgt'], help="""Specifies
|
||||||
|
'src' or 'tgt' side for 'field' file_type.""")
|
||||||
|
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
@ -72,8 +73,16 @@ def main():
|
||||||
reverse=True):
|
reverse=True):
|
||||||
f.write("{0}\n".format(w))
|
f.write("{0}\n".format(w))
|
||||||
else:
|
else:
|
||||||
|
if opt.side not in ['src', 'tgt']:
|
||||||
|
raise ValueError("If using -file_type='field', specifies "
|
||||||
|
"'src' or 'tgt' argument for -side.")
|
||||||
import torch
|
import torch
|
||||||
|
try:
|
||||||
from onmt.inputters.inputter import _old_style_vocab
|
from onmt.inputters.inputter import _old_style_vocab
|
||||||
|
except ImportError:
|
||||||
|
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||||
|
from onmt.inputters.inputter import _old_style_vocab
|
||||||
|
|
||||||
print("Reading input file...")
|
print("Reading input file...")
|
||||||
if not len(opt.file) == 1:
|
if not len(opt.file) == 1:
|
||||||
raise ValueError("If using -file_type='field', only pass one "
|
raise ValueError("If using -file_type='field', only pass one "
|
||||||
|
|
|
@ -137,7 +137,7 @@ def main():
|
||||||
logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
|
logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
|
||||||
% calc_vocab_load_stats(enc_vocab, src_vectors))
|
% calc_vocab_load_stats(enc_vocab, src_vectors))
|
||||||
logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
|
logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
|
||||||
% calc_vocab_load_stats(dec_vocab, src_vectors))
|
% calc_vocab_load_stats(dec_vocab, tgt_vectors))
|
||||||
|
|
||||||
# Write to file
|
# Write to file
|
||||||
enc_output_file = opt.output_file + ".enc.pt"
|
enc_output_file = opt.output_file + ".enc.pt"
|
||||||
|
|
|
@@ -1,16 +1,6 @@
 #!/usr/bin/env python
-import argparse
-import torch
+from onmt.bin.release_model import main


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Removes the optim data of PyTorch models")
-    parser.add_argument("--model", "-m",
-                        help="The model filename (*.pt)", required=True)
-    parser.add_argument("--output", "-o",
-                        help="The output filename (*.pt)", required=True)
-    opt = parser.parse_args()
-
-    model = torch.load(opt.model)
-    model['optim'] = None
-    torch.save(model, opt.output)
+    main()
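The removed body shows what "releasing" a model means here: the optimizer state is dropped so the saved checkpoint shrinks considerably; that logic now lives behind `onmt.bin.release_model.main`. A minimal equivalent of the removed script body (file names are placeholders):

```python
# Equivalent of the removed script body; file names are placeholders.
import torch

ckpt = torch.load("model_step_100000.pt", map_location="cpu")
ckpt["optim"] = None  # drop optimizer state to shrink the checkpoint
torch.save(ckpt, "model_release.pt")
```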
opennmt/train.py (198 changed lines)

@@ -1,200 +1,6 @@
 #!/usr/bin/env python
-"""Train models."""
+from onmt.bin.train import main
-import os
-import signal
-import torch
-
-import onmt.opts as opts
-import onmt.utils.distributed
-
-from onmt.utils.misc import set_random_seed
-from onmt.utils.logging import init_logger, logger
-from onmt.train_single import main as single_main
-from onmt.utils.parse import ArgumentParser
-from onmt.inputters.inputter import build_dataset_iter, \
-    load_old_vocab, old_style_vocab, build_dataset_iter_multiple
-
-from itertools import cycle
-
-
-def main(opt):
-    ArgumentParser.validate_train_opts(opt)
-    ArgumentParser.update_model_opts(opt)
-    ArgumentParser.validate_model_opts(opt)
-
-    # Load checkpoint if we resume from a previous training.
-    if opt.train_from:
-        logger.info('Loading checkpoint from %s' % opt.train_from)
-        checkpoint = torch.load(opt.train_from,
-                                map_location=lambda storage, loc: storage)
-        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
-        vocab = checkpoint['vocab']
-    else:
-        vocab = torch.load(opt.data + '.vocab.pt')
-
-    # check for code where vocab is saved instead of fields
-    # (in the future this will be done in a smarter way)
-    if old_style_vocab(vocab):
-        fields = load_old_vocab(
-            vocab, opt.model_type, dynamic_dict=opt.copy_attn)
-    else:
-        fields = vocab
-
-    if len(opt.data_ids) > 1:
-        train_shards = []
-        for train_id in opt.data_ids:
-            shard_base = "train_" + train_id
-            train_shards.append(shard_base)
-        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
-    else:
-        if opt.data_ids[0] is not None:
-            shard_base = "train_" + opt.data_ids[0]
-        else:
-            shard_base = "train"
-        train_iter = build_dataset_iter(shard_base, fields, opt)
-
-    nb_gpu = len(opt.gpu_ranks)
-
-    if opt.world_size > 1:
-        queues = []
-        mp = torch.multiprocessing.get_context('spawn')
-        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
-        # Create a thread to listen for errors in the child processes.
-        error_queue = mp.SimpleQueue()
-        error_handler = ErrorHandler(error_queue)
-        # Train with multiprocessing.
-        procs = []
-        for device_id in range(nb_gpu):
-            q = mp.Queue(opt.queue_size)
-            queues += [q]
-            procs.append(mp.Process(target=run, args=(
-                opt, device_id, error_queue, q, semaphore), daemon=True))
-            procs[device_id].start()
-            logger.info(" Starting process pid: %d " % procs[device_id].pid)
-            error_handler.add_child(procs[device_id].pid)
-        producer = mp.Process(target=batch_producer,
-                              args=(train_iter, queues, semaphore, opt,),
-                              daemon=True)
-        producer.start()
-        error_handler.add_child(producer.pid)
-
-        for p in procs:
-            p.join()
-        producer.terminate()
-
-    elif nb_gpu == 1:  # case 1 GPU only
-        single_main(opt, 0)
-    else:  # case only CPU
-        single_main(opt, -1)
-
-
-def batch_producer(generator_to_serve, queues, semaphore, opt):
-    init_logger(opt.log_file)
-    set_random_seed(opt.seed, False)
-    # generator_to_serve = iter(generator_to_serve)
-
-    def pred(x):
-        """
-        Filters batches that belong only
-        to gpu_ranks of current node
-        """
-        for rank in opt.gpu_ranks:
-            if x[0] % opt.world_size == rank:
-                return True
-
-    generator_to_serve = filter(
-        pred, enumerate(generator_to_serve))
-
-    def next_batch(device_id):
-        new_batch = next(generator_to_serve)
-        semaphore.acquire()
-        return new_batch[1]
-
-    b = next_batch(0)
-
-    for device_id, q in cycle(enumerate(queues)):
-        b.dataset = None
-        if isinstance(b.src, tuple):
-            b.src = tuple([_.to(torch.device(device_id))
-                           for _ in b.src])
-        else:
-            b.src = b.src.to(torch.device(device_id))
-        b.tgt = b.tgt.to(torch.device(device_id))
-        b.indices = b.indices.to(torch.device(device_id))
-        b.alignment = b.alignment.to(torch.device(device_id)) \
-            if hasattr(b, 'alignment') else None
-        b.src_map = b.src_map.to(torch.device(device_id)) \
-            if hasattr(b, 'src_map') else None
-
-        # hack to dodge unpicklable `dict_keys`
-        b.fields = list(b.fields)
-        q.put(b)
-        b = next_batch(device_id)
-
-
-def run(opt, device_id, error_queue, batch_queue, semaphore):
-    """ run process """
-    try:
-        gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
-        if gpu_rank != opt.gpu_ranks[device_id]:
-            raise AssertionError("An error occurred in \
-                  Distributed initialization")
-        single_main(opt, device_id, batch_queue, semaphore)
-    except KeyboardInterrupt:
-        pass  # killed by parent, do nothing
-    except Exception:
-        # propagate exception to parent process, keeping original traceback
-        import traceback
-        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
-
-
-class ErrorHandler(object):
-    """A class that listens for exceptions in children processes and propagates
-    the tracebacks to the parent process."""
-
-    def __init__(self, error_queue):
-        """ init error handler """
-        import signal
-        import threading
-        self.error_queue = error_queue
-        self.children_pids = []
-        self.error_thread = threading.Thread(
-            target=self.error_listener, daemon=True)
-        self.error_thread.start()
-        signal.signal(signal.SIGUSR1, self.signal_handler)
-
-    def add_child(self, pid):
-        """ error handler """
-        self.children_pids.append(pid)
-
-    def error_listener(self):
-        """ error listener """
-        (rank, original_trace) = self.error_queue.get()
-        self.error_queue.put((rank, original_trace))
-        os.kill(os.getpid(), signal.SIGUSR1)
-
-    def signal_handler(self, signalnum, stackframe):
-        """ signal handler """
-        for pid in self.children_pids:
-            os.kill(pid, signal.SIGINT)  # kill children processes
-        (rank, original_trace) = self.error_queue.get()
-        msg = """\n\n-- Tracebacks above this line can probably
-                 be ignored --\n\n"""
-        msg += original_trace
-        raise Exception(msg)
-
-
-def _get_parser():
-    parser = ArgumentParser(description='train.py')
-
-    opts.config_opts(parser)
-    opts.model_opts(parser)
-    opts.train_opts(parser)
-    return parser


 if __name__ == "__main__":
-    parser = _get_parser()
+    main()
-
-    opt = parser.parse_args()
-    main(opt)
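With this change `opennmt/train.py` becomes a thin wrapper around `onmt.bin.train.main`; the multi-GPU process spawning, the batch producer, and the `ErrorHandler` shown above now live behind that entry point rather than in the top-level script. The core of the removed `batch_producer`/`run` pair is a bounded producer/consumer queue; a generic, hedged sketch of that pattern (not the library's code):

```python
# Generic sketch of the bounded producer/consumer pattern used above:
# the producer round-robins batches into per-GPU queues, and a semaphore
# limits how many batches are in flight at once.
from itertools import cycle

def producer(batches, queues, semaphore):
    for batch, q in zip(batches, cycle(queues)):
        semaphore.acquire()   # wait until a consumer frees a slot
        q.put(batch)

def consumer(q, semaphore):
    while True:
        batch = q.get()
        # ... run one training step on `batch` ...
        semaphore.release()   # free a slot for the producer
```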
@@ -1,52 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+from onmt.bin.translate import main
-
-from __future__ import unicode_literals
-from itertools import repeat
-
-from onmt.utils.logging import init_logger
-from onmt.utils.misc import split_corpus
-from onmt.translate.translator import build_translator
-
-import onmt.opts as opts
-from onmt.utils.parse import ArgumentParser
-
-import pickle
-
-
-def main(opt):
-    with open('../opt_data', 'wb') as f:
-        pickle.dump(opt, f)
-    ArgumentParser.validate_translate_opts(opt)
-    logger = init_logger(opt.log_file)
-
-    translator = build_translator(opt, report_score=True)
-    src_shards = split_corpus(opt.src, opt.shard_size)
-    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
-        if opt.tgt is not None else repeat(None)
-    shard_pairs = zip(src_shards, tgt_shards)
-
-    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
-        logger.info("Translating shard %d." % i)
-        translator.translate(
-            src=src_shard,
-            tgt=tgt_shard,
-            src_dir=opt.src_dir,
-            batch_size=opt.batch_size,
-            attn_debug=opt.attn_debug
-            )
-
-
-def _get_parser():
-    parser = ArgumentParser(description='translate.py')
-
-    opts.config_opts(parser)
-    opts.translate_opts(parser)
-    return parser


 if __name__ == "__main__":
-    parser = _get_parser()
+    main()
-
-    opt = parser.parse_args()
-    main(opt)
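Similarly, `translate.py` now delegates to `onmt.bin.translate.main`. The removed body also contained an INMT-specific hack that pickled the parsed options to `../opt_data` (presumably the tracked `opt_data` binary listed below); that code is no longer in this file. The remaining logic, splitting the corpus into shards and translating shard by shard, can be sketched as follows (simplified from the removed code; not the new entry point's exact signature):

```python
# Simplified sketch of the removed shard loop.
from itertools import repeat

def translate_in_shards(translator, src_shards, tgt_shards=None, batch_size=32):
    tgt_shards = tgt_shards if tgt_shards is not None else repeat(None)
    for i, (src_shard, tgt_shard) in enumerate(zip(src_shards, tgt_shards)):
        print("Translating shard %d." % i)
        translator.translate(src=src_shard, tgt=tgt_shard, batch_size=batch_size)
```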
Binary data: opt_data
Binary file not shown.
@@ -1,12 +1,6 @@
 Django==3.0.3
-numpy==1.16.3
+numpy
-six
-tqdm==4.30.*
-torch==1.1 -f https://download.pytorch.org/whl/torch_stable.html
-git+https://github.com/pytorch/text.git@master#wheel=torchtext
-future
 channels
-configargparse
 editdistance
 indic_transliteration
 jsonfield