Update OpenNMT-py to c20dbeac02688918607637f5f30ec73c0f17d817

This commit is contained in:
Sebastin Santy 2020-04-08 02:49:16 +05:30
Parent ea5d36d9e7
Commit 70ef77eede
99 changed files: 4358 additions and 2859 deletions

.gitignore (vendored)
View file

@@ -1,6 +1,5 @@
model/
*.pyc
db.sqlite3
opennmt/.git
pred.txt
conf.py

View file

@@ -4,5 +4,6 @@ RUN mkdir /inmt
WORKDIR /inmt
COPY requirements.txt /inmt/
RUN pip install -r requirements.txt
RUN cd OpenNMT-py
RUN python setup.py install
COPY . /inmt/

View file

@@ -3,25 +3,32 @@ Interactive Machine Translation app uses Django and jQuery as its tech stack. Pl
# Installation Instructions
1. Clone INMT and prepare MT models:
```
git clone https://github.com/microsoft/inmt
```
2. Make a new model folder using `mkdir model` where the models need to be placed. Models can be downloaded from [here](https://microsoft-my.sharepoint.com/:f:/p/t-sesan/Evsn3riZxktJuterr5A09lABTVjhaL_NoH430IMgkzws9Q?e=VXzX5T). These contain English–Hindi translation models in both directions. If you want to train your own models, refer to [Training MT Models](#training-mt-models).
3. The rest of the installation can be carried out either bare or using Docker. Docker is preferable for its ease of installation.
## Docker Installation
Assuming you have Docker set up on your system, simply run `docker-compose up -d`. This application requires at least 4GB of memory in order to run. Allot your Docker memory accordingly.
## Bare Installation
1. Install dependencies using `python -m pip install -r requirements.txt`. Be sure to check your Python version; this tool is compatible with Python 3.
2. Install OpenNMT dependencies using `cd opennmt && python setup.py install && cd -`
3. Run the migrations and start the server: `python manage.py makemigrations && python manage.py migrate && python manage.py runserver`
4. The server opens on port 8000 by default. Open `localhost:8000/simple` for the simple interface. The full command sequence is sketched below.
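For reference, a minimal sketch of the full bare sequence, assuming the repository root as the working directory, a Python 3 environment, and models already placed in `model/`:

```bash
# Sketch only -- adjust paths to your setup
python -m pip install -r requirements.txt
cd opennmt && python setup.py install && cd -
python manage.py makemigrations
python manage.py migrate
python manage.py runserver   # serves port 8000; open localhost:8000/simple for the simple interface
```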
## Training MT Models
OpenNMT is used as the translation engine to power INMT. In order to train your own models, you need parallel sentences in your desired language. The basic instructions are listed as follows:
1. Go to the opennmt folder: `cd opennmt`
2. Preprocess parallel (src & tgt) sentences:
```onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo```
3. Train your model (with GPUs):
```onmt_train -data data/demo -save_model demo-model```
For advanced instructions on the training process, refer to [OpenNMT docs](https://opennmt.net/OpenNMT-py/quickstart.html). A hedged translation sketch using the trained model follows below.
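Once training finishes, the model can be sanity-checked with OpenNMT's translation entry point. This is a hedged sketch: the checkpoint name (including the step suffix) depends on your training settings, and `data/src-test.txt` stands in for whichever held-out source file you have.

```bash
# Illustrative only -- substitute your own checkpoint and test file
onmt_translate -model demo-model_step_100000.pt \
               -src data/src-test.txt \
               -output pred.txt \
               -replace_unk -verbose
```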
# Contributing

View file

@@ -1,8 +1,7 @@
dist: xenial
language: python
python:
- "3.6"
git:
depth: false
addons:
@@ -14,13 +13,30 @@ addons:
- sox
before_install:
# Install CPU version of PyTorch.
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install torch==1.4.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi
- pip install --upgrade setuptools
- pip install -r requirements.txt
- pip install -r requirements.opt.txt
install:
- python setup.py install
env:
global:
# Doctr deploy key for OpenNMT/OpenNMT-py
- secure: "gL0Soefo1cQgAqwiHUrlNyZd/+SI1eJAAjLD3BEDQWXW160eXyjQAAujGgJoCirjOM7cPHVwLzwmK3S7Y3PVM3JOZguOX5Yl4uxMh/mhiEM+RG77SZyv4OGoLFsEQ8RTvIdYdtP6AwyjlkRDXvZql88TqFNYjpXDu8NG+JwEfiIoGIDYxxZ5SlbrZN0IqmQSZ4/CsV6VQiuq99Jn5kqi4MnUZBTcmhqjaztCP1omvsMRdbrG2IVhDKQOCDIO0kaPJrMy2SGzP4GV7ar52bdBtpeP3Xbm6ZOuhDNfds7M/OMHp1wGdl7XwKtolw9MeXhnGBC4gcrqhhMfcQ6XtfVLMLnsB09Ezl3FXX5zWgTB5Pm0X6TgnGrMA25MAdVqKGJpfqZxOKTh4EMb04b6OXrVbxZ88mp+V0NopuxwlTPD8PMfYLWlTe9chh1BnT0iQlLqeA4Hv3+NdpiFb4aq3V3cWTTgMqOoWSGq4t318pqIZ3qbBXBq12DLFgO5n6+M6ZrdxbDUGQvgh8nAiZcIEdodKJ4ABHi1SNCeWOzCoedUdegcbjShHfkMVmNKrncB18aRWwQ3GQJ5qdkjgJmC++uZmkS6+GPM8UmmAy1ZIkRW0aWiitjG6teqtvUHOofNd/TCxX4bhnxAj+mtVIrARCE/ci8topJ6uG4wVJ1TrIkUlAY="
jobs:
include:
- name: "Flake8 Lint Check"
env: LINT_CHECK
install: pip install flake8 pep8-naming==0.7.0
script: flake8
- name: "Build Docs"
install:
- pip install doctr
script:
- pip install -r docs/requirements.txt
- set -e
- cd docs/ && make html && cd ..
- doctr deploy --built-docs docs/build/html/ .
- name: "Unit tests"
# Please also add tests to `test/pull_request_chk.sh`.
script:
- wget -O /tmp/im2text.tgz http://lstm.seas.harvard.edu/latex/im2text_small.tgz; tar zxf /tmp/im2text.tgz -C /tmp/; head /tmp/im2text/src-train.txt > /tmp/im2text/src-train-head.txt; head /tmp/im2text/tgt-train.txt > /tmp/im2text/tgt-train-head.txt; head /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt
@@ -47,6 +63,9 @@ script:
# test nmt preprocessing w/ sharding and training w/copy
- head -50 data/src-val.txt > /tmp/src-val.txt; head -50 data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -shard_size 25 -dynamic_dict -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -copy_attn -train_steps 10 -pool_factor 10 && rm -rf /tmp/q*.pt
# test Graph neural network preprocessing and training
- cp data/ggnnsrc.txt /tmp/src-val.txt; cp data/ggnntgt.txt /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -src_seq_length 1000 -tgt_seq_length 30 -src_vocab data/ggnnsrcvocab.txt -tgt_vocab data/ggnntgtvocab.txt -dynamic_dict -save_data /tmp/q ; python train.py -data /tmp/q -encoder_type ggnn -layers 2 -decoder_type rnn -rnn_size 256 -learning_rate 0.1 -learning_rate_decay 0.8 -global_attention general -batch_size 32 -word_vec_size 256 -bridge -train_steps 10 -src_vocab data/ggnnsrcvocab.txt -n_edge_types 9 -state_dim 256 -n_steps 10 -n_node 64 && rm -rf /tmp/q*.pt
# test im2text preprocessing and training
- head -50 /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head -50 /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python preprocess.py -data_type img -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-val-head.txt -train_tgt /tmp/im2text/tgt-val-head.txt -valid_src /tmp/im2text/src-val-head.txt -valid_tgt /tmp/im2text/tgt-val-head.txt -save_data /tmp/im2text/q -tgt_seq_length 100; python train.py -model_type img -data /tmp/im2text/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 -pool_factor 10 && rm -rf /tmp/im2text/q*.pt
# test speech2text preprocessing and training
@@ -57,24 +76,3 @@ script:
- python translate.py -model onmt/tests/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 1 -seed 1 -random_sampling_topk "-1" -random_sampling_temp 0.0001 -tgt data/morph/tgt.valid -out /tmp/trans; diff data/morph/tgt.valid /tmp/trans
# test tool
- PYTHONPATH=$PYTHONPATH:. python tools/extract_embeddings.py -model onmt/tests/test_model.pt
env:
global:
# Doctr deploy key for OpenNMT/OpenNMT-py
- secure: "gL0Soefo1cQgAqwiHUrlNyZd/+SI1eJAAjLD3BEDQWXW160eXyjQAAujGgJoCirjOM7cPHVwLzwmK3S7Y3PVM3JOZguOX5Yl4uxMh/mhiEM+RG77SZyv4OGoLFsEQ8RTvIdYdtP6AwyjlkRDXvZql88TqFNYjpXDu8NG+JwEfiIoGIDYxxZ5SlbrZN0IqmQSZ4/CsV6VQiuq99Jn5kqi4MnUZBTcmhqjaztCP1omvsMRdbrG2IVhDKQOCDIO0kaPJrMy2SGzP4GV7ar52bdBtpeP3Xbm6ZOuhDNfds7M/OMHp1wGdl7XwKtolw9MeXhnGBC4gcrqhhMfcQ6XtfVLMLnsB09Ezl3FXX5zWgTB5Pm0X6TgnGrMA25MAdVqKGJpfqZxOKTh4EMb04b6OXrVbxZ88mp+V0NopuxwlTPD8PMfYLWlTe9chh1BnT0iQlLqeA4Hv3+NdpiFb4aq3V3cWTTgMqOoWSGq4t318pqIZ3qbBXBq12DLFgO5n6+M6ZrdxbDUGQvgh8nAiZcIEdodKJ4ABHi1SNCeWOzCoedUdegcbjShHfkMVmNKrncB18aRWwQ3GQJ5qdkjgJmC++uZmkS6+GPM8UmmAy1ZIkRW0aWiitjG6teqtvUHOofNd/TCxX4bhnxAj+mtVIrARCE/ci8topJ6uG4wVJ1TrIkUlAY="
matrix:
include:
- env: LINT_CHECK
python: "2.7"
install: pip install flake8 pep8-naming==0.7.0
script: flake8
- python: "3.5"
install:
- python setup.py install
- pip install doctr
script:
- pip install -r docs/requirements.txt
- set -e
- cd docs/ && make html && cd ..
- doctr deploy --built-docs docs/build/html/ .

View file

@@ -3,7 +3,60 @@
## [Unreleased]
## [1.1.1](https://github.com/OpenNMT/OpenNMT-py/tree/1.1.1) (2020-03-20)
### Fixes and improvements
* Fix backcompatibility when no 'corpus_id' field (c313c28)
## [1.1.0](https://github.com/OpenNMT/OpenNMT-py/tree/1.1.0) (2020-03-19)
### New features
* Support CTranslate2 models in REST server (91d5d57)
* Extend support for custom preprocessing/postprocessing function in REST server by using return dictionaries (d14613d, 9619ac3, 92a7ba5)
* Experimental: BART-like source noising (5940dcf)
### Fixes and improvements
* Add options to CTranslate2 release (e442f3f)
* Fix dataset shard order (458fc48)
* Rotate only the server logs, not training (189583a)
* Fix alignment error with empty prediction (91287eb)
## [1.0.2](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.2) (2020-03-05)
### Fixes and improvements
* Enable CTranslate2 conversion of Transformers with relative position (db11135)
* Adapt `-replace_unk` to use with learned alignments if they exist (7625b53)
## [1.0.1](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.1) (2020-02-17)
### Fixes and improvements
* Ctranslate2 conversion handled in release script (1b50e0c)
* Use `attention_dropout` properly in MHA (f5c9cd4)
* Update apex FP16_Optimizer path (d3e2268)
* Some REST server optimizations
* Fix and add some docs
## [1.0.0](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.0) (2019-10-01)
### New features
* Implementation of "Jointly Learning to Align & Translate with Transformer" (@Zenglinxiao)
### Fixes and improvements
* Add nbest support to REST server (@Zenglinxiao)
* Merge greedy and beam search codepaths (@Zenglinxiao)
* Fix "block ngram repeats" (@KaijuML, @pltrdy)
* Small fixes, some more docs
## [1.0.0.rc2](https://github.com/OpenNMT/OpenNMT-py/tree/1.0.0.rc1) (2019-10-01)
* Fix Apex / FP16 training (Apex new API is buggy)
* Multithread preprocessing way faster (Thanks @francoishernandez)
* Pip Installation v1.0.0.rc1 (thanks @pltrdy)
## [0.9.2](https://github.com/OpenNMT/OpenNMT-py/tree/0.9.2) (2019-09-04)
* Switch to Pytorch 1.2
* Pre/post processing on the translation server
* option to remove the FFN layer in AAN + AAN optimization (faster)
* Coverage loss (per Abisee paper 2017) implementation
* Video Captioning task: Thanks Dylan Flaute!
* Token batch at inference
* Small fixes and add-ons
## [0.9.1](https://github.com/OpenNMT/OpenNMT-py/tree/0.9.1) (2019-06-13)
* New mechanism for MultiGPU training "1 batch producer / multi batch consumers"

View file

@@ -1,2 +0,0 @@
FROM pytorch/pytorch:latest
RUN git clone https://github.com/OpenNMT/OpenNMT-py.git && cd OpenNMT-py && pip install -r requirements.txt && python setup.py install

View file

@@ -3,7 +3,7 @@
[![Build Status](https://travis-ci.org/OpenNMT/OpenNMT-py.svg?branch=master)](https://travis-ci.org/OpenNMT/OpenNMT-py)
[![Run on FH](https://img.shields.io/badge/Run%20on-FloydHub-blue.svg)](https://floydhub.com/run?template=https://github.com/OpenNMT/OpenNMT-py)
This is a [PyTorch](https://github.com/pytorch/pytorch)
port of [OpenNMT](https://github.com/OpenNMT/OpenNMT),
an open-source (MIT) neural machine translation system. It is designed to be research friendly to try out new ideas in translation, summary, image-to-text, morphology, and many other domains. Some companies have proven the code to be production ready.
@@ -28,32 +28,43 @@ Table of Contents
## Requirements
Install `OpenNMT-py` from `pip`:
```bash
pip install OpenNMT-py
```
or from the sources:
```bash
git clone https://github.com/OpenNMT/OpenNMT-py.git
cd OpenNMT-py
python setup.py install
```
Note: If you have a MemoryError during the install, try to use `pip` with `--no-cache-dir`.
*(Optional)* some advanced features (e.g. working audio, image or pretrained models) require extra packages; you can install them with:
```bash
pip install -r requirements.opt.txt
```
Note:
- some features require Python 3.5 and after (e.g. distributed multi-GPU, entmax)
- we currently only support PyTorch 1.4
## Features
- [Seq2Seq models (encoder-decoder) with multiple RNN cells (lstm/gru) and attention (dotprod/mlp) types](http://opennmt.net/OpenNMT-py/options/train.html#model-encoder-decoder)
- [Transformer models](http://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-the-transformer-model)
- [Copy and Coverage Attention](http://opennmt.net/OpenNMT-py/options/train.html#model-attention)
- [Pretrained Embeddings](http://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-pretrained-embeddings-e-g-glove)
- [Source word features](http://opennmt.net/OpenNMT-py/options/train.html#model-embeddings)
- [Image-to-text processing](http://opennmt.net/OpenNMT-py/im2text.html)
- [Speech-to-text processing](http://opennmt.net/OpenNMT-py/speech2text.html)
- [TensorBoard logging](http://opennmt.net/OpenNMT-py/options/train.html#logging)
- [Multi-GPU training](http://opennmt.net/OpenNMT-py/FAQ.html##do-you-support-multi-gpu)
- [Data preprocessing](http://opennmt.net/OpenNMT-py/options/preprocess.html)
- [Inference (translation) with batching and beam search](http://opennmt.net/OpenNMT-py/options/translate.html)
- Inference time loss functions.
- [Conv2Conv convolution model]
- SRU "RNNs faster than CNN" paper
@@ -67,7 +78,7 @@ Note that we currently only support PyTorch 1.1 (should work with 1.0)
### Step 1: Preprocess the data
```bash
onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo
```
We will be working with some example data in `data/` folder.
@@ -94,21 +105,21 @@ Internally the system never touches the words themselves, but uses these indices
### Step 2: Train the model
```bash
onmt_train -data data/demo -save_model demo-model
```
The main train command is quite simple. Minimally it takes a data file
and a save file. This will run the default model, which consists of a
2-layer LSTM with 500 hidden units on both the encoder/decoder.
If you want to train on GPU, you need to set, as an example:
`CUDA_VISIBLE_DEVICES=1,3`
`-world_size 2 -gpu_ranks 0 1` to use (say) GPU 1 and 3 on this node only.
To know more about distributed training on single or multi nodes, read the FAQ section. A minimal two-GPU example is sketched below.
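Putting those two settings together, a minimal sketch of a single-node run on physical GPUs 1 and 3 with the demo data from Step 1 would be:

```bash
export CUDA_VISIBLE_DEVICES=1,3
onmt_train -data data/demo -save_model demo-model -world_size 2 -gpu_ranks 0 1
```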
### Step 3: Translate
```bash
onmt_translate -model demo-model_acc_XX.XX_ppl_XXX.XX_eX.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
```
Now you have a model which you can use to predict on new data. We do this by running beam search. This will output predictions into `pred.txt`.
@@ -151,7 +162,7 @@ Major contributors are:
[Dylan Flaute](http://github.com/flauted) (University of Dayton)
and more!
OpenNMT-py belongs to the OpenNMT project along with OpenNMT-Lua and OpenNMT-tf.
## Citation

View file

@@ -1,13 +0,0 @@
{
"models_root": "./available_models",
"models":[{
"model": "onmt-hien.pt",
"timeout": -1,
"on_timeout": "unload",
"model_root": "../model/",
"opt": {
"batch_size": 1,
"beam_size": 10
}
}]
}

View file

@@ -4,3 +4,4 @@ sphinxcontrib.mermaid
sphinx-rtd-theme
recommonmark
sphinx-argparse
sphinx_markdown_tables

View file

@@ -8,7 +8,7 @@ the script is a slightly modified version of ylhsieh's one2.
Usage:
```shell
embeddings_to_torch.py [-h] [-emb_file_both EMB_FILE_BOTH]
[-emb_file_enc EMB_FILE_ENC]
[-emb_file_dec EMB_FILE_DEC] -output_file
@@ -16,23 +16,23 @@ embeddings_to_torch.py [-h] [-emb_file_both EMB_FILE_BOTH]
[-skip_lines SKIP_LINES]
[-type {GloVe,word2vec}]
```
Run embeddings_to_torch.py -h for more complete info.
### Example
1. Get GloVe files:
```shell
mkdir "glove_dir"
wget http://nlp.stanford.edu/data/glove.6B.zip
unzip glove.6B.zip -d "glove_dir"
```
2. Prepare data:
```shell
onmt_preprocess \
-train_src data/train.src.txt \
-train_tgt data/train.tgt.txt \
-valid_src data/valid.src.txt \
@@ -40,18 +40,18 @@ python preprocess.py \
-save_data data/data
```
3. Prepare embeddings:
```shell
./tools/embeddings_to_torch.py -emb_file_both "glove_dir/glove.6B.100d.txt" \
-dict_file "data/data.vocab.pt" \
-output_file "data/embeddings"
```
4. Train using pre-trained embeddings:
```shell
onmt_train -save_model data/model \
-batch_size 64 \
-layers 2 \
-rnn_size 200 \
@@ -61,14 +61,13 @@ python train.py -save_model data/model \
-data data/data
```
## How do I use the Transformer model?
The transformer model is very sensitive to hyperparameters. To run it
effectively you need to set a bunch of different options that mimic the Google
setup. We have confirmed the following command can replicate their WMT results.
```shell
python train.py -data /tmp/de2/data -save_model /tmp/extra \
-layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 \
-encoder_type transformer -decoder_type transformer -position_encoding \
@@ -86,24 +85,34 @@ Here are what each of the parameters mean:
* `position_encoding`: add sinusoidal position encoding to each embedding
* `optim adam`, `decay_method noam`, `warmup_steps 8000`: use special learning rate.
* `batch_type tokens`, `normalization tokens`, `accum_count 4`: batch and normalize based on number of tokens and not sentences. Compute gradients based on four batches.
* `label_smoothing 0.1`: use label smoothing loss.
## Do you support multi-gpu?
First you need to make sure you `export CUDA_VISIBLE_DEVICES=0,1,2,3`.
If you want to use GPU id 1 and 3 of your OS, you will need to `export CUDA_VISIBLE_DEVICES=1,3`.
Both `-world_size` and `-gpu_ranks` need to be set. E.g. `-world_size 4 -gpu_ranks 0 1 2 3` will use 4 GPU on this node only.
If you want to use 2 nodes with 2 GPU each, you need to set `-master_ip` and `-master_port`, and
* `-world_size 4 -gpu_ranks 0 1`: on the first node
* `-world_size 4 -gpu_ranks 2 3`: on the second node
* `-accum_count 2`: This will accumulate over 2 batches before updating parameters.
If you use a regular network card (1 Gbps) then we suggest to use a higher `-accum_count` to minimize the inter-node communication.
**Note:**
When training on several GPUs, you can't have them in 'Exclusive' compute mode (`nvidia-smi -c 3`).
The multi-gpu setup relies on a Producer/Consumer setup. This setup means there will be `2<n_gpu> + 1` processes spawned, with 2 processes per GPU, one for model training and one (Consumer) that hosts a `Queue` of batches that will be processed next. The additional process is the Producer, creating batches and sending them to the Consumers. This setup is beneficial for both wall time and memory, since it loads data shards 'in advance', and does not require to load it for each GPU process.
## How can I ensemble Models at inference?
You can specify several models in the translate.py command line: -model model1_seed1 model2_seed2
Bear in mind that your models must share the same target vocabulary. A hedged example follows below.
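As a sketch (the checkpoint names are placeholders), an ensemble run only differs from a normal translation run in the list passed to `-model`:

```shell
# Hypothetical checkpoints trained with different seeds on the same target vocabulary
onmt_translate -model model1_seed1.pt model2_seed2.pt \
               -src data/src-test.txt -output pred.txt -beam_size 5
```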
## How can I weight different corpora at training?
@@ -112,7 +121,8 @@ Bear in mind that your models must share the same target vocabulary.
We introduced `-train_ids` which is a list of IDs that will be given to the preprocessed shards.
E.g. we have two corpora : `parallel.en` and `parallel.de` + `from_backtranslation.en` `from_backtranslation.de`, we can pass the following in the `preprocess.py` command:
```shell
...
-train_src parallel.en from_backtranslation.en \
-train_tgt parallel.de from_backtranslation.de \
@@ -120,19 +130,55 @@ E.g. we have two corpora : `parallel.en` and `parallel.de` + `from_backtranslat
-save_data my_data \
...
```
and it will dump `my_data.train_A.X.pt` based on `parallel.en`//`parallel.de` and `my_data.train_B.X.pt` based on `from_backtranslation.en`//`from_backtranslation.de`.
### Training
We introduced `-data_ids` based on the same principle as above, as well as `-data_weights`, which is the list of the weight each corpus should have.
E.g.
```shell
...
-data my_data \
-data_ids A B \
-data_weights 1 7 \
...
```
will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`, and that when building batches, we'll take 1 example from corpus A, then 7 examples from corpus B, and so on.
**Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing.
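Putting the two halves together, a hedged end-to-end sketch with the corpora named above (the validation files and every omitted option are placeholders):

```shell
onmt_preprocess -train_src parallel.en from_backtranslation.en \
                -train_tgt parallel.de from_backtranslation.de \
                -train_ids A B \
                -valid_src dev.en -valid_tgt dev.de \
                -save_data my_data
onmt_train -data my_data -data_ids A B -data_weights 1 7 -save_model weighted-model
```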
## Can I get word alignment while translating?
### Raw alignments from averaging Transformer attention heads
Currently, we support producing word alignment while translating for Transformer based models. Using `-report_align` when calling `translate.py` will output the inferred alignments in Pharaoh format. Those alignments are computed from an argmax on the average of the attention heads of the *second to last* decoder layer. The resulting alignment src-tgt (Pharaoh) will be pasted to the translation sentence, separated by ` ||| `.
Note: The *second to last* default behaviour was empirically determined. It is not the same as the paper (they take the *penultimate* layer), probably because of light differences in the architecture.
* alignments use the standard "Pharaoh format", where a pair `i-j` indicates the i<sub>th</sub> word of source language is aligned to j<sub>th</sub> word of target language.
* Example: {'src': 'das stimmt nicht !'; 'output': 'that is not true ! ||| 0-0 0-1 1-2 2-3 1-4 1-5 3-6'}
* Using the `-tgt` option when calling `translate.py`, we output alignments between the source and the gold target rather than the inferred target, assuming we're doing evaluation.
* To convert subword alignments to word alignments, or symmetrize bidirectional alignments, please refer to the [lilt scripts](https://github.com/lilt/alignment-scripts). A hedged command sketch follows this list.
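A command sketch for the raw-alignment path (the Transformer checkpoint name is a placeholder):

```shell
# Only Transformer-based models are supported, per the note above
onmt_translate -model transformer-model.pt \
               -src data/src-test.txt -output pred.txt \
               -report_align -gpu 0
```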
### Supervised learning on a specific head
The quality of output alignments can be further improved by providing reference alignments while training. This will invoke multi-task learning on translation and alignment. This is an implementation based on the paper [Jointly Learning to Align and Translate with Transformer Models](https://arxiv.org/abs/1909.02074).
The data need to be preprocessed with the reference alignments in order to learn the supervised task.
When calling `preprocess.py`, add:
* `--train_align <path>`: path(s) to the training alignments in Pharaoh format
* `--valid_align <path>`: path to the validation set alignments in Pharaoh format (optional).
The reference alignment file(s) could be generated by [GIZA++](https://github.com/moses-smt/mgiza/) or [fast_align](https://github.com/clab/fast_align).
Note: There should be no blank lines in the alignment files provided.
Options to learn such alignments are:
* `-lambda_align`: set the value > 0.0 to enable joint align training, the paper suggests 0.05;
* `-alignment_layer`: indicate the index of the decoder layer;
* `-alignment_heads`: number of alignment heads for the alignment task - should be set to 1 for the supervised task, and preferably kept to default (or same as `num_heads`) for the average task;
* `-full_context_alignment`: do full context decoder pass (no future mask) when computing alignments. This will slow down the training (~12% in terms of tok/s) but will be beneficial to generate better alignment. A combined preprocessing and training sketch follows this list.
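These options slot into the usual Transformer recipe. The following is a hedged sketch; file names and the Transformer hyperparameters are illustrative rather than taken from this page:

```shell
onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt \
                -train_align data/train.align \
                -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt \
                -valid_align data/val.align \
                -save_data data/aligned
onmt_train -data data/aligned -save_model aligned-model \
           -layers 6 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 \
           -encoder_type transformer -decoder_type transformer -position_encoding \
           -lambda_align 0.05 -alignment_heads 1 -full_context_alignment
```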

View file

@@ -3,7 +3,7 @@
For this example, we will assume that we have run preprocess to
create our datasets. For instance
> onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 10000 -tgt_vocab_size 10000

View file

@@ -35,8 +35,8 @@ For CNN-DM we follow See et al. [2] and additionally truncate the source length
(1) CNN-DM
```bash
onmt_preprocess -train_src data/cnndm/train.txt.src \
-train_tgt data/cnndm/train.txt.tgt.tagged \
-valid_src data/cnndm/val.txt.src \
-valid_tgt data/cnndm/val.txt.tgt.tagged \
@@ -52,8 +52,8 @@ python preprocess.py -train_src data/cnndm/train.txt.src \
(2) Gigaword
```bash
onmt_preprocess -train_src data/giga/train.article.txt \
-train_tgt data/giga/train.title.txt \
-valid_src data/giga/valid.article.txt \
-valid_tgt data/giga/valid.title.txt \
@@ -87,8 +87,8 @@ We additionally set the maximum norm of the gradient to 2, and renormalize if th
(1) CNN-DM
```bash
onmt_train -save_model models/cnndm \
-data data/cnndm/CNNDM \
-copy_attn \
-global_attention mlp \
@@ -116,8 +116,8 @@ python train.py -save_model models/cnndm \
The following script trains the transformer model on CNN-DM
```bash
onmt_train -data data/cnndm/CNNDM \
-save_model models/cnndm \
-layers 4 \
-rnn_size 512 \
@@ -152,7 +152,7 @@ python -u train.py -data data/cnndm/CNNDM \
Gigaword can be trained equivalently. As a baseline, we show a model trained with the following command:
```
onmt_train -data data/giga/GIGA \
-save_model models/giga \
-copy_attn \
-reuse_copy_attn \
@@ -177,7 +177,7 @@ During inference, we use beam-search with a beam-size of 10. We also added speci
(1) CNN-DM
```
onmt_translate -gpu X \
-batch_size 20 \
-beam_size 10 \
-model models/cnndm... \
@@ -221,8 +221,6 @@ For evaluation of large test sets such as Gigaword, we use the a parallel python
### Scores and Models
The website generator has trouble rendering tables, if you can't read the results, please go [here](https://github.com/OpenNMT/OpenNMT-py/blob/master/docs/source/Summarization.md) for correct format.
#### CNN-DM
| Model Type | Model | R1 R | R1 P | R1 F | R2 R | R2 P | R2 F | RL R | RL P | RL F |
@@ -231,7 +229,7 @@ The website generator has trouble rendering tables, if you can't read the result
| Pointer-Generator [2] | [link](https://github.com/abisee/pointer-generator) | 37.76 | 37.60| 36.44| 16.31| 16.12| 15.66| 34.66| 34.46| 33.42 |
| OpenNMT BRNN (1 layer, emb 128, hid 512) | [link](https://s3.amazonaws.com/opennmt-models/Summary/ada6_bridge_oldcopy_tagged_acc_54.17_ppl_11.17_e20.pt) | 40.90| 40.20| 39.02| 17.91| 17.99| 17.25| 37.76 | 37.18| 36.05 |
| OpenNMT BRNN (1 layer, emb 128, hid 512, shared embeddings) | [link](https://s3.amazonaws.com/opennmt-models/Summary/ada6_bridge_oldcopy_tagged_share_acc_54.50_ppl_10.89_e20.pt) | 38.59 | 40.60 | 37.97 | 16.75 | 17.93 | 16.59 | 35.67 | 37.60 | 35.13 |
| OpenNMT BRNN (2 layer, emb 256, hid 1024) | [link](https://s3.amazonaws.com/opennmt-models/Summary/ada6_bridge_oldcopy_tagged_larger_acc_54.84_ppl_10.58_e17.pt) | 40.41 | 40.94 | 39.12 | 17.76 | 18.38 | 17.35 | 37.27 | 37.83 | 36.12 |
| OpenNMT Transformer | [link](https://s3.amazonaws.com/opennmt-models/sum_transformer_model_acc_57.25_ppl_9.22_e16.pt) | 40.31 | 41.09 | 39.25 | 17.97 | 18.46 | 17.54 | 37.41 | 38.18 | 36.45 |

View file

@@ -49,7 +49,8 @@ extensions = ['sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinxcontrib.mermaid',
'sphinxcontrib.bibtex',
'sphinxarg.ext',
'sphinx_markdown_tables']
# Show base classes
autodoc_default_options = {

View file

@@ -17,19 +17,19 @@ Step 1. Preprocess the data.
```bash
for l in en de; do for f in data/multi30k/*.$l; do if [[ "$f" != *"test"* ]]; then sed -i "$ d" $f; fi; done; done
for l in en de; do for f in data/multi30k/*.$l; do perl tools/tokenizer.perl -a -no-escape -l $l -q < $f > $f.atok; done; done
onmt_preprocess -train_src data/multi30k/train.en.atok -train_tgt data/multi30k/train.de.atok -valid_src data/multi30k/val.en.atok -valid_tgt data/multi30k/val.de.atok -save_data data/multi30k.atok.low -lower
```
Step 2. Train the model.
```bash
onmt_train -data data/multi30k.atok.low -save_model multi30k_model -gpu_ranks 0
```
Step 3. Translate sentences.
```bash
onmt_translate -gpu 0 -model multi30k_model_*_e13.pt -src data/multi30k/test2016.en.atok -tgt data/multi30k/test2016.de.atok -replace_unk -verbose -output multi30k.test.pred.atok
```
And evaluate

View file

@@ -0,0 +1,94 @@
# Gated Graph Sequence Neural Networks
Graph-to-sequence networks allow information representable as a graph (such as an annotated NLP sentence or a computer code structure such as an AST) to be connected to a sequence generator to produce output which can benefit from the graph structure of the input.
The training option `-encoder_type ggnn` implements a GGNN (Gated Graph Neural Network) based on github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.
The ggnn encoder is used for program equivalence proof generation in the paper <a href="https://arxiv.org/abs/2002.06799">Equivalence of Dataflow Graphs via Rewrite Rules Using a Graph-to-Sequence Neural Model</a>. That paper shows the benefit of the graph-to-sequence model over a sequence-to-sequence model for this problem which can be well represented with graphical input. The integration of the ggnn network into the <a href="https://github.com/OpenNMT/OpenNMT-py/">OpenNMT-py</a> system supports attention on the nodes as well as a copy mechanism.
### Dependencies
* There are no additional dependencies beyond the rnn-to-rnn sequence-to-sequence requirements.
### Quick Start
To get started, we provide a toy graph-to-sequence example. We assume that the working directory is `OpenNMT-py` throughout this document.
0) Download the data to a sibling directory.
```
cd ..
git clone https://github.com/SteveKommrusch/OpenNMT-py-ggnn-example
source OpenNMT-py-ggnn-example/env.sh
cd OpenNMT-py
```
1) Preprocess the data.
```
python preprocess.py -train_src $data_path/src-train.txt -train_tgt $data_path/tgt-train.txt -valid_src $data_path/src-val.txt -valid_tgt $data_path/tgt-val.txt -src_seq_length 1000 -tgt_seq_length 30 -src_vocab $data_path/srcvocab.txt -tgt_vocab $data_path/tgtvocab.txt -dynamic_dict -save_data $data_path/final 2>&1 > $data_path/preprocess.out
```
2) Train the model.
```
python train.py -data $data_path/final -encoder_type ggnn -layers 2 -decoder_type rnn -rnn_size 256 -learning_rate 0.1 -start_decay_steps 5000 -learning_rate_decay 0.8 -global_attention general -batch_size 32 -word_vec_size 256 -bridge -train_steps 10000 -gpu_ranks 0 -save_checkpoint_steps 5000 -save_model $data_path/final-model -src_vocab $data_path/srcvocab.txt -n_edge_types 9 -state_dim 256 -n_steps 10 -n_node 64 > $data_path/train.final.out
```
3) Translate the graph of 2 equivalent linear algebra expressions into the axiom list which proves them equivalent.
```
python translate.py -model $data_path/final-model_step_10000.pt -src $data_path/src-test.txt -beam_size 5 -n_best 5 -gpu 0 -output $data_path/pred-test_beam5.txt -dynamic_dict 2>&1 > $data_path/translate5.out
```
### Graph data format
The GGNN implementation leverages the sequence processing and vocabulary
interface of OpenNMT. Each graph is provided on an input line, much like
a sentence is provided on an input line. A graph neural network input line
includes `sentence tokens`, `feature values`, and `edges` separated by
`<EOT>` (end of tokens) tokens. Below is an example of the input for a pair
of algebraic equations structured as a graph:
```
Sentence tokens Feature values Edges
--------------- ------------------ -------------------------------------------------------
- - - 0 a a b b <EOT> 0 1 2 3 4 4 2 3 12 <EOT> 0 2 1 3 2 4 , 0 6 1 7 2 5 , 0 4 , 0 5 , , , , 8 0 , 8 1
```
The equations being represented are `((a - a) - b)` and `(0 - b)`, the
`sentence tokens` of which are provided before the first `<EOT>`. After
the first `<EOT>`, the `features values` are provided. These are extra
flags with information on each node in the graph. In this case, the 8
sentence tokens have feature flags ranging from 0 to 4; the 9th feature
flag defines a 9th node in the graph which does not have sentence token
information, just feature data. Nodes with any non-number flag (such as
`-` or `.`) will not have a feature added. Multiple groups of features
can be provided by using the `,` delimiter between the first and second
'<EOT>' tokens. After the second `<EOT>` token, edge information is provided.
Edge data is given as node pairs, hence `<EOT> 0 2 1 3` indicates that there
are edges from node 0 to node 2 and from node 1 to node 3. The GGNN supports
multiple edge types (which result mathematically in multiple weight matrices
for the model) and the edge types are separated by `,` tokens after the
second `<EOT>` token.
Note that the source vocabulary file needs to include the '<EOT>' token,
the ',' token, and all of the numbers used for feature flags and node
identifiers in the edge list.
### Options
* `-rnn_type (str)`: style of recurrent unit to use, one of [LSTM]
* `-state_dim (int)`: Number of state dimensions in nodes
* `-n_edge_types (int)`: Number of edge types
* `-bidir_edges (bool)`: True if reverse edges should be automatically created
* `-n_node (int)`: Max nodes in graph
* `-bridge_extra_node (bool)`: True indicates only the vector from the 1st extra node (after token listing) should be used for decoder initialization; False indicates all node vectors should be averaged together for decoder initialization
* `-n_steps (int)`: Steps to advance graph encoder for stabilization
* `-src_vocab (int)`: Path to source vocabulary
### Acknowledgement
This gated graph neural network is leveraged from github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.

View file

@@ -27,31 +27,51 @@ Im2Text consists of four commands:
0) Download the data.
```bash
wget -O data/im2text.tgz http://lstm.seas.harvard.edu/latex/im2text_small.tgz; tar zxf data/im2text.tgz -C data/
```
1) Preprocess the data.
```bash
onmt_preprocess -data_type img \
-src_dir data/im2text/images/ \
-train_src data/im2text/src-train.txt \
-train_tgt data/im2text/tgt-train.txt -valid_src data/im2text/src-val.txt \
-valid_tgt data/im2text/tgt-val.txt -save_data data/im2text/demo \
-tgt_seq_length 150 \
-tgt_words_min_frequency 2 \
-shard_size 500 \
-image_channel_size 1
```
2) Train the model.
```bash
onmt_train -model_type img \
-data data/im2text/demo \
-save_model demo-model \
-gpu_ranks 0 \
-batch_size 20 \
-max_grad_norm 20 \
-learning_rate 0.1 \
-word_vec_size 80 \
-encoder_type brnn \
-image_channel_size 1
```
3) Translate the images.
```bash
onmt_translate -data_type img \
-model demo-model_acc_x_ppl_x_e13.pt \
-src_dir data/im2text/images \
-src data/im2text/src-test.txt \
-output pred.txt \
-max_length 150 \
-beam_size 5 \
-gpu 0 \
-verbose
```
The above dataset is sampled from the [im2latex-100k-dataset](http://lstm.seas.harvard.edu/latex/im2text.tgz). We provide a trained model [[link]](http://lstm.seas.harvard.edu/latex/py-model.pt) on this dataset.

View file

@@ -6,20 +6,22 @@ This portal provides a detailed documentation of the OpenNMT toolkit. It describ
## Installation
Install `OpenNMT-py` from `pip`:
```bash
pip install OpenNMT-py
```
or from the sources:
```bash
git clone https://github.com/OpenNMT/OpenNMT-py.git
cd OpenNMT-py
python setup.py install
```
*(Optional)* some advanced features (e.g. working audio, image or pretrained models) require extra packages; you can install them with:
```bash
pip install -r requirements.opt.txt
```
And you are ready to go! Take a look at the [quickstart](quickstart) to familiarize yourself with the main training workflow.

View file

@@ -25,9 +25,9 @@ Decoding Strategies
.. autoclass:: onmt.translate.BeamSearch
:members:
.. autofunction:: onmt.translate.greedy_search.sample_with_temperature
.. autoclass:: onmt.translate.GreedySearch
:members:
Scoring

View file

@@ -2,6 +2,6 @@ Preprocess
==========
.. argparse::
:filename: ../onmt/bin/preprocess.py
:func: _get_parser
:prog: preprocess.py

View file

@@ -2,6 +2,6 @@ Server
=========
.. argparse::
:filename: ../onmt/bin/server.py
:func: _get_parser
:prog: server.py

View file

@@ -2,6 +2,6 @@ Train
=====
.. argparse::
:filename: ../onmt/bin/train.py
:func: _get_parser
:prog: train.py

View file

@@ -2,6 +2,6 @@ Translate
=========
.. argparse::
:filename: ../onmt/bin/translate.py
:func: _get_parser
:prog: translate.py

View file

@@ -6,7 +6,7 @@
### Step 1: Preprocess the data
```bash
onmt_preprocess -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/demo
```
We will be working with some example data in `data/` folder.
@@ -30,7 +30,7 @@ Federal Master Trainer and Senior Instructor of the Italian Federation of Aerobi
### Step 2: Train the model
```bash
onmt_train -data data/demo -save_model demo-model
```
The main train command is quite simple. Minimally it takes a data file
@@ -44,7 +44,7 @@ To know more about distributed training on single or multi nodes, read the FAQ s
### Step 3: Translate
```bash
onmt_translate -model demo-model_XYZ.pt -src data/src-test.txt -output pred.txt -replace_unk -verbose
```
Now you have a model which you can use to predict on new data. We do this by running beam search. This will output predictions into `pred.txt`.

View file

@@ -20,6 +20,22 @@
year={2016}
}
@inproceedings{Li2016,
author = {Yujia Li and
Daniel Tarlow and
Marc Brockschmidt and
Richard S. Zemel},
title = {Gated Graph Sequence Neural Networks},
booktitle = {4th International Conference on Learning Representations, {ICLR} 2016,
San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings},
year = {2016},
crossref = {DBLP:conf/iclr/2016},
url = {http://arxiv.org/abs/1511.05493},
timestamp = {Thu, 25 Jul 2019 14:25:40 +0200},
biburl = {https://dblp.org/rec/journals/corr/LiTBZ15.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@inproceedings{Bahdanau2015,
archivePrefix = {arXiv},
arxivId = {1409.0473},
@@ -435,3 +451,33 @@ year = {2015}
biburl = {https://dblp.org/rec/bib/journals/corr/MartinsA16},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@inproceedings{garg2019jointly,
title = {Jointly Learning to Align and Translate with Transformer Models},
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
address = {Hong Kong},
month = {November},
url = {https://arxiv.org/abs/1909.02074},
year = {2019},
}
@inproceedings{DeeperTransformer,
title = "Learning Deep Transformer Models for Machine Translation",
author = "Wang, Qiang and
Li, Bei and
Xiao, Tong and
Zhu, Jingbo and
Li, Changliang and
Wong, Derek F. and
Chao, Lidia S.",
booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2019",
address = "Florence, Italy",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/P19-1176",
doi = "10.18653/v1/P19-1176",
pages = "1810--1822",
abstract = "Transformer is the state-of-the-art model in recent machine translation evaluations. Two strands of research are promising to improve models of this kind: the first uses wide networks (a.k.a. Transformer-Big) and has been the de facto standard for development of the Transformer system, and the other uses deeper language representation but faces the difficulty arising from learning deep networks. Here, we continue the line of research on the latter. We claim that a truly deep Transformer model can surpass the Transformer-Big counterpart by 1) proper use of layer normalization and 2) a novel way of passing the combination of previous layers to the next. On WMT{'}16 English-German and NIST OpenMT{'}12 Chinese-English tasks, our deep system (30/25-layer encoder) outperforms the shallow Transformer-Big/Base baseline (6-layer encoder) by 0.4-2.4 BLEU points. As another bonus, the deep model is 1.6X smaller in size and 3X faster in training than Transformer-Big.",
}

Просмотреть файл

@ -23,19 +23,19 @@ wget -O data/speech.tgz http://lstm.seas.harvard.edu/latex/speech.tgz; tar zxf d
1) Preprocess the data.
```
onmt_preprocess -data_type audio -src_dir data/speech/an4_dataset -train_src data/speech/src-train.txt -train_tgt data/speech/tgt-train.txt -valid_src data/speech/src-val.txt -valid_tgt data/speech/tgt-val.txt -shard_size 300 -save_data data/speech/demo
```
2) Train the model.
```
onmt_train -model_type audio -enc_rnn_size 512 -dec_rnn_size 512 -audio_enc_pooling 1,1,2,2 -dropout 0 -enc_layers 4 -dec_layers 1 -rnn_type LSTM -data data/speech/demo -save_model demo-model -global_attention mlp -gpu_ranks 0 -batch_size 8 -optim adam -max_grad_norm 100 -learning_rate 0.0003 -learning_rate_decay 0.8 -train_steps 100000
```
3) Translate the speech.
```
onmt_translate -data_type audio -model demo-model_acc_x_ppl_x_e13.pt -src_dir data/speech/an4_dataset -src data/speech/src-val.txt -output pred.txt -gpu 0 -verbose
```

Просмотреть файл

@ -158,19 +158,19 @@ Preprocess the data with
.. code-block:: bash
onmt_preprocess -data_type vec -train_src $YT2T/yt2t_train_files.txt -src_dir $YT2T/r152/ -train_tgt $YT2T/yt2t_train_cap.txt -valid_src $YT2T/yt2t_val_files.txt -valid_tgt $YT2T/yt2t_val_cap.txt -save_data data/yt2t --shard_size 1000
Train with
.. code-block:: bash
onmt_train -data data/yt2t -save_model yt2t-model -world_size 2 -gpu_ranks 0 1 -model_type vec -batch_size 64 -train_steps 10000 -valid_steps 500 -save_checkpoint_steps 500 -encoder_type brnn -optim adam -learning_rate .0001 -feat_vec_size 2048
Translate with
.. code-block::
onmt_translate -model yt2t-model_step_7200.pt -src $YT2T/yt2t_test_files.txt -output pred.txt -verbose -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 10
.. note::
@ -385,13 +385,13 @@ old data. Do this, i.e. ``rm data/yt2t.*.pt.``)
.. code-block:: bash
onmt_preprocess -data_type vec -train_src $YT2T/yt2t_train_files.txt -src_dir $YT2T/r152/ -train_tgt $YT2T/yt2t_train_cap.txt -valid_src $YT2T/yt2t_val_files.txt -valid_tgt $YT2T/yt2t_val_cap.txt -save_data data/yt2t --shard_size 1000 --src_seq_length 50 --tgt_seq_length 20
Delete the old checkpoints and train a transformer model on this data.
.. code-block:: bash
rm -r yt2t-model_step_*.pt; onmt_train -data data/yt2t -save_model yt2t-model -world_size 2 -gpu_ranks 0 1 -model_type vec -batch_size 64 -train_steps 8000 -valid_steps 400 -save_checkpoint_steps 400 -optim adam -learning_rate .0001 -feat_vec_size 2048 -layers 4 -rnn_size 512 -word_vec_size 512 -transformer_ff 2048 -heads 8 -encoder_type transformer -decoder_type transformer -position_encoding -dropout 0.3 -param_init 0 -param_init_glorot -report_every 400 --share_decoder_embedding --seed 7000
Note we use the hyperparameters described in the paper.
We estimate the length of 20 epochs with ``-train_steps``. Note that this depends on
@ -502,7 +502,7 @@ although this too is not mentioned. You can reproduce our early-stops with these
for file in $( ls -1v yt2t-model_step*.pt )
do
echo $file
onmt_translate -model $file -src $YT2T/yt2t_val_folded_files.txt -output pred.txt -verbose -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 16 -max_length 20 >/dev/null 2>/dev/null
echo -e "$file\n" >> results.txt
python coco.py -s val >> results.txt
echo -e "\n\n" >> results.txt
@ -519,7 +519,7 @@ although this too is not mentioned. You can reproduce our early-stops with these
metric=$(echo $line | awk '{print $1}')
step=$(echo $line | awk '{print $NF}')
echo $metric early stopped @ $step | tee -a test_results.txt
onmt_translate -model "yt2t-model_step_${step}.pt" -src $YT2T/yt2t_test_files.txt -output pred.txt -data_type vec -src_dir $YT2T/r152 -gpu 0 -batch_size 16 -max_length 20 >/dev/null 2>/dev/null
python coco.py -s 'test' >> test_results.txt
echo -e "\n\n" >> test_results.txt
fi

Просмотреть файл

@ -17,4 +17,4 @@ sys.modules["onmt.Optim"] = onmt.utils.optimizers
__all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models,
onmt.utils, onmt.modules, "Trainer"]
__version__ = "1.1.1"

Просмотреть файл

Просмотреть файл

@ -0,0 +1,54 @@
#!/usr/bin/env python
import argparse
import torch
def average_models(model_files, fp32=False):
vocab = None
opt = None
avg_model = None
avg_generator = None
for i, model_file in enumerate(model_files):
m = torch.load(model_file, map_location='cpu')
model_weights = m['model']
generator_weights = m['generator']
if fp32:
for k, v in model_weights.items():
model_weights[k] = v.float()
for k, v in generator_weights.items():
generator_weights[k] = v.float()
if i == 0:
vocab, opt = m['vocab'], m['opt']
avg_model = model_weights
avg_generator = generator_weights
else:
for (k, v) in avg_model.items():
avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
for (k, v) in avg_generator.items():
avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
final = {"vocab": vocab, "opt": opt, "optim": None,
"generator": avg_generator, "model": avg_model}
return final
def main():
parser = argparse.ArgumentParser(description="")
parser.add_argument("-models", "-m", nargs="+", required=True,
help="List of models")
parser.add_argument("-output", "-o", required=True,
help="Output file")
parser.add_argument("-fp32", "-f", action="store_true",
help="Cast params to float32")
opt = parser.parse_args()
final = average_models(opt.models, opt.fp32)
torch.save(final, opt.output)
if __name__ == "__main__":
main()
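A quick usage sketch for the checkpoint-averaging script above (the script path and checkpoint names are illustrative; the flags come straight from the argparse block in the file):
```bash
python opennmt/onmt/bin/average_models.py \
    -m demo-model_step_9000.pt demo-model_step_10000.pt \
    -o demo-model_avg.pt -fp32
```
Averaging the last few checkpoints in this way is a common trick to smooth out noise from the final optimization steps before evaluating or releasing a model.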

315
opennmt/onmt/bin/preprocess.py Executable file
Просмотреть файл

@ -0,0 +1,315 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Pre-process Data / features files and build vocabulary
"""
import codecs
import glob
import gc
import torch
from collections import Counter, defaultdict
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import split_corpus
import onmt.inputters as inputters
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import _build_fields_vocab,\
_load_vocab, \
old_style_vocab, \
load_old_vocab
from functools import partial
from multiprocessing import Pool
def check_existing_pt_files(opt, corpus_type, ids, existing_fields):
""" Check if there are existing .pt files to avoid overwriting them """
existing_shards = []
for maybe_id in ids:
if maybe_id:
shard_base = corpus_type + "_" + maybe_id
else:
shard_base = corpus_type
pattern = opt.save_data + '.{}.*.pt'.format(shard_base)
if glob.glob(pattern):
if opt.overwrite:
maybe_overwrite = ("will be overwritten because "
"`-overwrite` option is set.")
else:
maybe_overwrite = ("won't be overwritten, pass the "
"`-overwrite` option if you want to.")
logger.warning("Shards for corpus {} already exist, {}"
.format(shard_base, maybe_overwrite))
existing_shards += [maybe_id]
return existing_shards
def process_one_shard(corpus_params, params):
corpus_type, fields, src_reader, tgt_reader, align_reader, opt,\
existing_fields, src_vocab, tgt_vocab = corpus_params
i, (src_shard, tgt_shard, align_shard, maybe_id, filter_pred) = params
# create one counter per shard
sub_sub_counter = defaultdict(Counter)
assert len(src_shard) == len(tgt_shard)
logger.info("Building shard %d." % i)
src_data = {"reader": src_reader, "data": src_shard, "dir": opt.src_dir}
tgt_data = {"reader": tgt_reader, "data": tgt_shard, "dir": None}
align_data = {"reader": align_reader, "data": align_shard, "dir": None}
_readers, _data, _dir = inputters.Dataset.config(
[('src', src_data), ('tgt', tgt_data), ('align', align_data)])
dataset = inputters.Dataset(
fields, readers=_readers, data=_data, dirs=_dir,
sort_key=inputters.str2sortkey[opt.data_type],
filter_pred=filter_pred,
corpus_id=maybe_id
)
if corpus_type == "train" and existing_fields is None:
for ex in dataset.examples:
sub_sub_counter['corpus_id'].update(
["train" if maybe_id is None else maybe_id])
for name, field in fields.items():
if ((opt.data_type == "audio") and (name == "src")):
continue
try:
f_iter = iter(field)
except TypeError:
f_iter = [(name, field)]
all_data = [getattr(ex, name, None)]
else:
all_data = getattr(ex, name)
for (sub_n, sub_f), fd in zip(
f_iter, all_data):
has_vocab = (sub_n == 'src' and
src_vocab is not None) or \
(sub_n == 'tgt' and
tgt_vocab is not None)
if (hasattr(sub_f, 'sequential')
and sub_f.sequential and not has_vocab):
val = fd
sub_sub_counter[sub_n].update(val)
if maybe_id:
shard_base = corpus_type + "_" + maybe_id
else:
shard_base = corpus_type
data_path = "{:s}.{:s}.{:d}.pt".\
format(opt.save_data, shard_base, i)
logger.info(" * saving %sth %s data shard to %s."
% (i, shard_base, data_path))
dataset.save(data_path)
del dataset.examples
gc.collect()
del dataset
gc.collect()
return sub_sub_counter
def maybe_load_vocab(corpus_type, counters, opt):
src_vocab = None
tgt_vocab = None
existing_fields = None
if corpus_type == "train":
if opt.src_vocab != "":
try:
logger.info("Using existing vocabulary...")
existing_fields = torch.load(opt.src_vocab)
except torch.serialization.pickle.UnpicklingError:
logger.info("Building vocab from text file...")
src_vocab, src_vocab_size = _load_vocab(
opt.src_vocab, "src", counters,
opt.src_words_min_frequency)
if opt.tgt_vocab != "":
tgt_vocab, tgt_vocab_size = _load_vocab(
opt.tgt_vocab, "tgt", counters,
opt.tgt_words_min_frequency)
return src_vocab, tgt_vocab, existing_fields
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader,
align_reader, opt):
assert corpus_type in ['train', 'valid']
if corpus_type == 'train':
counters = defaultdict(Counter)
srcs = opt.train_src
tgts = opt.train_tgt
ids = opt.train_ids
aligns = opt.train_align
elif corpus_type == 'valid':
counters = None
srcs = [opt.valid_src]
tgts = [opt.valid_tgt]
ids = [None]
aligns = [opt.valid_align]
src_vocab, tgt_vocab, existing_fields = maybe_load_vocab(
corpus_type, counters, opt)
existing_shards = check_existing_pt_files(
opt, corpus_type, ids, existing_fields)
# every corpus has shards, no new one
if existing_shards == ids and not opt.overwrite:
return
def shard_iterator(srcs, tgts, ids, aligns, existing_shards,
existing_fields, corpus_type, opt):
"""
Builds a single iterator yielding every shard of every corpus.
"""
for src, tgt, maybe_id, maybe_align in zip(srcs, tgts, ids, aligns):
if maybe_id in existing_shards:
if opt.overwrite:
logger.warning("Overwrite shards for corpus {}"
.format(maybe_id))
else:
if corpus_type == "train":
assert existing_fields is not None,\
("A 'vocab.pt' file should be passed to "
"`-src_vocab` when adding a corpus to "
"a set of already existing shards.")
logger.warning("Ignore corpus {} because "
"shards already exist"
.format(maybe_id))
continue
if ((corpus_type == "train" or opt.filter_valid)
and tgt is not None):
filter_pred = partial(
inputters.filter_example,
use_src_len=opt.data_type == "text",
max_src_len=opt.src_seq_length,
max_tgt_len=opt.tgt_seq_length)
else:
filter_pred = None
src_shards = split_corpus(src, opt.shard_size)
tgt_shards = split_corpus(tgt, opt.shard_size)
align_shards = split_corpus(maybe_align, opt.shard_size)
for i, (ss, ts, a_s) in enumerate(
zip(src_shards, tgt_shards, align_shards)):
yield (i, (ss, ts, a_s, maybe_id, filter_pred))
shard_iter = shard_iterator(srcs, tgts, ids, aligns, existing_shards,
existing_fields, corpus_type, opt)
with Pool(opt.num_threads) as p:
dataset_params = (corpus_type, fields, src_reader, tgt_reader,
align_reader, opt, existing_fields,
src_vocab, tgt_vocab)
func = partial(process_one_shard, dataset_params)
for sub_counter in p.imap(func, shard_iter):
if sub_counter is not None:
for key, value in sub_counter.items():
counters[key].update(value)
if corpus_type == "train":
vocab_path = opt.save_data + '.vocab.pt'
new_fields = _build_fields_vocab(
fields, counters, opt.data_type,
opt.share_vocab, opt.vocab_size_multiple,
opt.src_vocab_size, opt.src_words_min_frequency,
opt.tgt_vocab_size, opt.tgt_words_min_frequency,
subword_prefix=opt.subword_prefix,
subword_prefix_is_joiner=opt.subword_prefix_is_joiner)
if existing_fields is None:
fields = new_fields
else:
fields = existing_fields
if old_style_vocab(fields):
fields = load_old_vocab(
fields, opt.data_type, dynamic_dict=opt.dynamic_dict)
# patch corpus_id
if fields.get("corpus_id", False):
fields["corpus_id"].vocab = new_fields["corpus_id"].vocab_cls(
counters["corpus_id"])
torch.save(fields, vocab_path)
def build_save_vocab(train_dataset, fields, opt):
fields = inputters.build_vocab(
train_dataset, fields, opt.data_type, opt.share_vocab,
opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency,
opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency,
vocab_size_multiple=opt.vocab_size_multiple
)
vocab_path = opt.save_data + '.vocab.pt'
torch.save(fields, vocab_path)
def count_features(path):
"""
path: location of a corpus file with whitespace-delimited tokens and
-delimited features within the token
returns: the number of features in the dataset
"""
with codecs.open(path, "r", "utf-8") as f:
first_tok = f.readline().split(None, 1)[0]
return len(first_tok.split(u"")) - 1
def preprocess(opt):
ArgumentParser.validate_preprocess_args(opt)
torch.manual_seed(opt.seed)
init_logger(opt.log_file)
logger.info("Extracting features...")
src_nfeats = 0
tgt_nfeats = 0
for src, tgt in zip(opt.train_src, opt.train_tgt):
src_nfeats += count_features(src) if opt.data_type == 'text' \
else 0
tgt_nfeats += count_features(tgt) # tgt always text so far
logger.info(" * number of source features: %d." % src_nfeats)
logger.info(" * number of target features: %d." % tgt_nfeats)
logger.info("Building `Fields` object...")
fields = inputters.get_fields(
opt.data_type,
src_nfeats,
tgt_nfeats,
dynamic_dict=opt.dynamic_dict,
with_align=opt.train_align[0] is not None,
src_truncate=opt.src_seq_length_trunc,
tgt_truncate=opt.tgt_seq_length_trunc)
src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
tgt_reader = inputters.str2reader["text"].from_opt(opt)
align_reader = inputters.str2reader["text"].from_opt(opt)
logger.info("Building & saving training data...")
build_save_dataset(
'train', fields, src_reader, tgt_reader, align_reader, opt)
if opt.valid_src and opt.valid_tgt:
logger.info("Building & saving validation data...")
build_save_dataset(
'valid', fields, src_reader, tgt_reader, align_reader, opt)
def _get_parser():
parser = ArgumentParser(description='preprocess.py')
opts.config_opts(parser)
opts.preprocess_opts(parser)
return parser
def main():
parser = _get_parser()
opt = parser.parse_args()
preprocess(opt)
if __name__ == "__main__":
main()
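The preprocessing entry point above supports several named training corpora; a hedged sketch of how that might be invoked (corpus names and file paths are placeholders, and the flag names are assumed to mirror the `opt` attributes read in `build_save_dataset`):
```bash
onmt_preprocess -train_src data/src-news.txt data/src-web.txt \
    -train_tgt data/tgt-news.txt data/tgt-web.txt \
    -train_ids news web \
    -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt \
    -save_data data/demo -overwrite
```
Each id produces its own `train_<id>` shard set, and existing shards are only rewritten when `-overwrite` is passed, as the warning in `check_existing_pt_files` indicates.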

Просмотреть файл

@ -0,0 +1,59 @@
#!/usr/bin/env python
import argparse
import torch
def get_ctranslate2_model_spec(opt):
"""Creates a CTranslate2 model specification from the model options."""
with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
is_ct2_compatible = (
opt.encoder_type == "transformer"
and opt.decoder_type == "transformer"
and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot"
and ((opt.position_encoding and not with_relative_position)
or (with_relative_position and not opt.position_encoding)))
if not is_ct2_compatible:
return None
import ctranslate2
num_heads = getattr(opt, "heads", 8)
return ctranslate2.specs.TransformerSpec(
(opt.enc_layers, opt.dec_layers),
num_heads,
with_relative_position=with_relative_position)
def main():
parser = argparse.ArgumentParser(
description="Release an OpenNMT-py model for inference")
parser.add_argument("--model", "-m",
help="The model path", required=True)
parser.add_argument("--output", "-o",
help="The output path", required=True)
parser.add_argument("--format",
choices=["pytorch", "ctranslate2"],
default="pytorch",
help="The format of the released model")
parser.add_argument("--quantization", "-q",
choices=["int8", "int16"],
default=None,
help="Quantization type for CT2 model.")
opt = parser.parse_args()
model = torch.load(opt.model)
if opt.format == "pytorch":
model["optim"] = None
torch.save(model, opt.output)
elif opt.format == "ctranslate2":
model_spec = get_ctranslate2_model_spec(model["opt"])
if model_spec is None:
raise ValueError("This model is not supported by CTranslate2. Go "
"to https://github.com/OpenNMT/CTranslate2 for "
"more information on supported models.")
import ctranslate2
converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
converter.convert(opt.output, model_spec, force=True,
quantization=opt.quantization)
if __name__ == "__main__":
main()
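As a usage sketch (paths are illustrative), the release script can strip optimizer state for a plain PyTorch release or convert a compatible Transformer checkpoint to CTranslate2 with optional quantization:
```bash
python opennmt/onmt/bin/release_model.py \
    --model demo-model_step_100000.pt \
    --output demo-model_released.pt \
    --format ctranslate2 --quantization int8
```
Note that `get_ctranslate2_model_spec` only accepts Transformer encoder/decoder models with scaled-dot self-attention; anything else raises the ValueError shown above.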

157
opennmt/onmt/bin/server.py Executable file
Просмотреть файл

@ -0,0 +1,157 @@
#!/usr/bin/env python
import configargparse
from flask import Flask, jsonify, request
from waitress import serve
from onmt.translate import TranslationServer, ServerModelError
import logging
from logging.handlers import RotatingFileHandler
STATUS_OK = "ok"
STATUS_ERROR = "error"
def start(config_file,
url_root="./translator",
host="0.0.0.0",
port=5000,
debug=False):
def prefix_route(route_function, prefix='', mask='{0}{1}'):
def newroute(route, *args, **kwargs):
return route_function(mask.format(prefix, route), *args, **kwargs)
return newroute
if debug:
logger = logging.getLogger("main")
log_format = logging.Formatter(
"[%(asctime)s %(levelname)s] %(message)s")
file_handler = RotatingFileHandler(
"debug_requests.log",
maxBytes=1000000, backupCount=10)
file_handler.setFormatter(log_format)
logger.addHandler(file_handler)
app = Flask(__name__)
app.route = prefix_route(app.route, url_root)
translation_server = TranslationServer()
translation_server.start(config_file)
@app.route('/models', methods=['GET'])
def get_models():
out = translation_server.list_models()
return jsonify(out)
@app.route('/health', methods=['GET'])
def health():
out = {}
out['status'] = STATUS_OK
return jsonify(out)
@app.route('/clone_model/<int:model_id>', methods=['POST'])
def clone_model(model_id):
out = {}
data = request.get_json(force=True)
timeout = -1
if 'timeout' in data:
timeout = data['timeout']
del data['timeout']
opt = data.get('opt', None)
try:
model_id, load_time = translation_server.clone_model(
model_id, opt, timeout)
except ServerModelError as e:
out['status'] = STATUS_ERROR
out['error'] = str(e)
else:
out['status'] = STATUS_OK
out['model_id'] = model_id
out['load_time'] = load_time
return jsonify(out)
@app.route('/unload_model/<int:model_id>', methods=['GET'])
def unload_model(model_id):
out = {"model_id": model_id}
try:
translation_server.unload_model(model_id)
out['status'] = STATUS_OK
except Exception as e:
out['status'] = STATUS_ERROR
out['error'] = str(e)
return jsonify(out)
@app.route('/translate', methods=['POST'])
def translate():
inputs = request.get_json(force=True)
if debug:
logger.info(inputs)
out = {}
try:
trans, scores, n_best, _, aligns = translation_server.run(inputs)
assert len(trans) == len(inputs) * n_best
assert len(scores) == len(inputs) * n_best
assert len(aligns) == len(inputs) * n_best
out = [[] for _ in range(n_best)]
for i in range(len(trans)):
response = {"src": inputs[i // n_best]['src'], "tgt": trans[i],
"n_best": n_best, "pred_score": scores[i]}
if aligns[i][0] is not None:
response["align"] = aligns[i]
out[i % n_best].append(response)
except ServerModelError as e:
model_id = inputs[0].get("id")
if debug:
logger.warning("Unload model #{} "
"because of an error".format(model_id))
translation_server.models[model_id].unload()
out['error'] = str(e)
out['status'] = STATUS_ERROR
if debug:
logger.info(out)
return jsonify(out)
@app.route('/to_cpu/<int:model_id>', methods=['GET'])
def to_cpu(model_id):
out = {'model_id': model_id}
translation_server.models[model_id].to_cpu()
out['status'] = STATUS_OK
return jsonify(out)
@app.route('/to_gpu/<int:model_id>', methods=['GET'])
def to_gpu(model_id):
out = {'model_id': model_id}
translation_server.models[model_id].to_gpu()
out['status'] = STATUS_OK
return jsonify(out)
serve(app, host=host, port=port)
def _get_parser():
parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
description="OpenNMT-py REST Server")
parser.add_argument("--ip", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default="5000")
parser.add_argument("--url_root", type=str, default="/translator")
parser.add_argument("--debug", "-d", action="store_true")
parser.add_argument("--config", "-c", type=str,
default="./available_models/conf.json")
return parser
def main():
parser = _get_parser()
args = parser.parse_args()
start(args.config, url_root=args.url_root, host=args.ip, port=args.port,
debug=args.debug)
if __name__ == "__main__":
main()
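Once the server is running (by default it listens on port 5000 with `/translator` as the URL root), translation requests are posted as a JSON list; a minimal sketch, assuming a model with id 100 is declared in your `conf.json`:
```bash
curl -X POST http://localhost:5000/translator/translate \
     -H "Content-Type: application/json" \
     -d '[{"src": "Hello world", "id": 100}]'
```
The handler above returns one list per n-best rank, each entry echoing the source and carrying the hypothesis, its score and, when available, the word alignment.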

213
opennmt/onmt/bin/train.py Executable file
Просмотреть файл

@ -0,0 +1,213 @@
#!/usr/bin/env python
"""Train models."""
import os
import signal
import torch
import onmt.opts as opts
import onmt.utils.distributed
from onmt.utils.misc import set_random_seed
from onmt.utils.logging import init_logger, logger
from onmt.train_single import main as single_main
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import build_dataset_iter, patch_fields, \
load_old_vocab, old_style_vocab, build_dataset_iter_multiple
from itertools import cycle
def train(opt):
ArgumentParser.validate_train_opts(opt)
ArgumentParser.update_model_opts(opt)
ArgumentParser.validate_model_opts(opt)
set_random_seed(opt.seed, False)
# Load checkpoint if we resume from a previous training.
if opt.train_from:
logger.info('Loading checkpoint from %s' % opt.train_from)
checkpoint = torch.load(opt.train_from,
map_location=lambda storage, loc: storage)
logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
vocab = checkpoint['vocab']
else:
vocab = torch.load(opt.data + '.vocab.pt')
# check for code where vocab is saved instead of fields
# (in the future this will be done in a smarter way)
if old_style_vocab(vocab):
fields = load_old_vocab(
vocab, opt.model_type, dynamic_dict=opt.copy_attn)
else:
fields = vocab
# patch for fields that may be missing in old data/model
patch_fields(opt, fields)
if len(opt.data_ids) > 1:
train_shards = []
for train_id in opt.data_ids:
shard_base = "train_" + train_id
train_shards.append(shard_base)
train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
else:
if opt.data_ids[0] is not None:
shard_base = "train_" + opt.data_ids[0]
else:
shard_base = "train"
train_iter = build_dataset_iter(shard_base, fields, opt)
nb_gpu = len(opt.gpu_ranks)
if opt.world_size > 1:
queues = []
mp = torch.multiprocessing.get_context('spawn')
semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
for device_id in range(nb_gpu):
q = mp.Queue(opt.queue_size)
queues += [q]
procs.append(mp.Process(target=run, args=(
opt, device_id, error_queue, q, semaphore), daemon=True))
procs[device_id].start()
logger.info(" Starting process pid: %d " % procs[device_id].pid)
error_handler.add_child(procs[device_id].pid)
producer = mp.Process(target=batch_producer,
args=(train_iter, queues, semaphore, opt,),
daemon=True)
producer.start()
error_handler.add_child(producer.pid)
for p in procs:
p.join()
producer.terminate()
elif nb_gpu == 1: # case 1 GPU only
single_main(opt, 0)
else: # case only CPU
single_main(opt, -1)
def batch_producer(generator_to_serve, queues, semaphore, opt):
init_logger(opt.log_file)
set_random_seed(opt.seed, False)
# generator_to_serve = iter(generator_to_serve)
def pred(x):
"""
Filters batches that belong only
to gpu_ranks of current node
"""
for rank in opt.gpu_ranks:
if x[0] % opt.world_size == rank:
return True
generator_to_serve = filter(
pred, enumerate(generator_to_serve))
def next_batch(device_id):
new_batch = next(generator_to_serve)
semaphore.acquire()
return new_batch[1]
b = next_batch(0)
for device_id, q in cycle(enumerate(queues)):
b.dataset = None
if isinstance(b.src, tuple):
b.src = tuple([_.to(torch.device(device_id))
for _ in b.src])
else:
b.src = b.src.to(torch.device(device_id))
b.tgt = b.tgt.to(torch.device(device_id))
b.indices = b.indices.to(torch.device(device_id))
b.alignment = b.alignment.to(torch.device(device_id)) \
if hasattr(b, 'alignment') else None
b.src_map = b.src_map.to(torch.device(device_id)) \
if hasattr(b, 'src_map') else None
b.align = b.align.to(torch.device(device_id)) \
if hasattr(b, 'align') else None
b.corpus_id = b.corpus_id.to(torch.device(device_id)) \
if hasattr(b, 'corpus_id') else None
# hack to dodge unpicklable `dict_keys`
b.fields = list(b.fields)
q.put(b)
b = next_batch(device_id)
def run(opt, device_id, error_queue, batch_queue, semaphore):
""" run process """
try:
gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
if gpu_rank != opt.gpu_ranks[device_id]:
raise AssertionError("An error occurred in \
Distributed initialization")
single_main(opt, device_id, batch_queue, semaphore)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""
def __init__(self, error_queue):
""" init error handler """
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(
target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
""" error handler """
self.children_pids.append(pid)
def error_listener(self):
""" error listener """
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, signalnum, stackframe):
""" signal handler """
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
msg = """\n\n-- Tracebacks above this line can probably
be ignored --\n\n"""
msg += original_trace
raise Exception(msg)
def _get_parser():
parser = ArgumentParser(description='train.py')
opts.config_opts(parser)
opts.model_opts(parser)
opts.train_opts(parser)
return parser
def main():
parser = _get_parser()
opt = parser.parse_args()
train(opt)
if __name__ == "__main__":
main()
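A rough sketch of launching the multi-GPU path that `batch_producer` and the spawned worker processes implement (data and model names are placeholders; single-GPU and CPU runs simply drop or shrink `-world_size`/`-gpu_ranks`):
```bash
onmt_train -data data/demo -save_model demo-model \
    -world_size 2 -gpu_ranks 0 1 \
    -batch_size 64 -train_steps 100000
```
With `-world_size` greater than 1, one producer process feeds batches through per-GPU queues while each worker runs `single_main` on its own device.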

52
opennmt/onmt/bin/translate.py Executable file
Просмотреть файл

@ -0,0 +1,52 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from onmt.utils.logging import init_logger
from onmt.utils.misc import split_corpus
from onmt.translate.translator import build_translator
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
def translate(opt):
ArgumentParser.validate_translate_opts(opt)
logger = init_logger(opt.log_file)
translator = build_translator(opt, report_score=True)
src_shards = split_corpus(opt.src, opt.shard_size)
tgt_shards = split_corpus(opt.tgt, opt.shard_size)
shard_pairs = zip(src_shards, tgt_shards)
for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
logger.info("Translating shard %d." % i)
translator.translate(
src=src_shard,
tgt=tgt_shard,
src_dir=opt.src_dir,
batch_size=opt.batch_size,
batch_type=opt.batch_type,
attn_debug=opt.attn_debug,
align_debug=opt.align_debug
)
def _get_parser():
parser = ArgumentParser(description='translate.py')
opts.config_opts(parser)
opts.translate_opts(parser)
return parser
def main():
parser = _get_parser()
opt = parser.parse_args()
translate(opt)
if __name__ == "__main__":
main()

Просмотреть файл

@ -190,7 +190,8 @@ class RNNDecoderBase(DecoderBase):
self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"]) self.state["hidden"] = tuple(h.detach() for h in self.state["hidden"])
self.state["input_feed"] = self.state["input_feed"].detach() self.state["input_feed"] = self.state["input_feed"].detach()
def forward(self, tgt, memory_bank, memory_lengths=None, step=None): def forward(self, tgt, memory_bank, memory_lengths=None, step=None,
**kwargs):
""" """
Args: Args:
tgt (LongTensor): sequences of padded tokens tgt (LongTensor): sequences of padded tokens

Просмотреть файл

@ -51,7 +51,8 @@ class EnsembleDecoder(DecoderBase):
super(EnsembleDecoder, self).__init__(attentional)
self.model_decoders = model_decoders
def forward(self, tgt, memory_bank, memory_lengths=None, step=None,
**kwargs):
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
# Memory_lengths is a single tensor shared between all models.
# This assumption will not hold if Translator is modified
@ -60,7 +61,7 @@ class EnsembleDecoder(DecoderBase):
dec_outs, attns = zip(*[
model_decoder(
tgt, memory_bank[i],
memory_lengths=memory_lengths, step=step, **kwargs)
for i, model_decoder in enumerate(self.model_decoders)])
mean_attns = self.combine_attns(attns)
return EnsembleDecoderOutput(dec_outs), mean_attns

Просмотреть файл

@ -12,25 +12,51 @@ from onmt.utils.misc import sequence_mask
class TransformerDecoderLayer(nn.Module):
"""Transformer Decoder layer block in Pre-Norm style.
Pre-Norm style is an improvement w.r.t. the original paper's Post-Norm style,
providing better convergence speed and performance. This is also the actual
implementation in tensor2tensor and is also available in fairseq.
See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
.. mermaid::
graph LR
%% "*SubLayer" can be self-attn, src-attn or feed forward block
A(input) --> B[Norm]
B --> C["*SubLayer"]
C --> D[Drop]
D --> E((+))
A --> E
E --> F(out)
Args:
d_model (int): the dimension of keys/values/queries in
:class:`MultiHeadedAttention`, also the input size of
the first-layer of the :class:`PositionwiseFeedForward`.
heads (int): the number of heads for MultiHeadedAttention.
d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
dropout (float): dropout in residual, self-attn(dot) and feed-forward
attention_dropout (float): dropout in context_attn (and self-attn(avg))
self_attn_type (string): type of self-attention scaled-dot, average
max_relative_positions (int):
Max distance between inputs in relative positions representations
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
full_context_alignment (bool):
whether enable an extra full context decoder forward for alignment
alignment_heads (int):
N. of cross attention heads to use for alignment guiding
""" """
def __init__(self, d_model, heads, d_ff, dropout, attention_dropout, def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
self_attn_type="scaled-dot", max_relative_positions=0, self_attn_type="scaled-dot", max_relative_positions=0,
aan_useffn=False): aan_useffn=False, full_context_alignment=False,
alignment_heads=0):
super(TransformerDecoderLayer, self).__init__() super(TransformerDecoderLayer, self).__init__()
if self_attn_type == "scaled-dot": if self_attn_type == "scaled-dot":
self.self_attn = MultiHeadedAttention( self.self_attn = MultiHeadedAttention(
heads, d_model, dropout=dropout, heads, d_model, dropout=attention_dropout,
max_relative_positions=max_relative_positions) max_relative_positions=max_relative_positions)
elif self_attn_type == "average": elif self_attn_type == "average":
self.self_attn = AverageAttention(d_model, self.self_attn = AverageAttention(d_model,
@ -43,54 +69,106 @@ class TransformerDecoderLayer(nn.Module):
self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
self.drop = nn.Dropout(dropout)
self.full_context_alignment = full_context_alignment
self.alignment_heads = alignment_heads
def forward(self, *args, **kwargs):
""" Extend `_forward` for (possibly) multiple decoder pass:
Always a default (future masked) decoder forward pass,
Possibly a second future aware decoder pass for joint learn
full context alignement, :cite:`garg2019jointly`.
Args:
* All arguments of _forward.
with_align (bool): whether return alignment attention.
Returns:
(FloatTensor, FloatTensor, FloatTensor or None):
* output ``(batch_size, T, model_dim)``
* top_attn ``(batch_size, T, src_len)``
* attn_align ``(batch_size, T, src_len)`` or None
"""
with_align = kwargs.pop('with_align', False)
output, attns = self._forward(*args, **kwargs)
top_attn = attns[:, 0, :, :].contiguous()
attn_align = None
if with_align:
if self.full_context_alignment:
# return _, (B, Q_len, K_len)
_, attns = self._forward(*args, **kwargs, future=True)
if self.alignment_heads > 0:
attns = attns[:, :self.alignment_heads, :, :].contiguous()
# layer average attention across heads, get ``(B, Q, K)``
# Case 1: no full_context, no align heads -> layer avg baseline
# Case 2: no full_context, 1 align heads -> guided align
# Case 3: full_context, 1 align heads -> full cte guided align
attn_align = attns.mean(dim=1)
return output, top_attn, attn_align
def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
layer_cache=None, step=None, future=False):
""" A naive forward pass for transformer decoder.
# T: could be 1 in the case of stepwise decoding or tgt_len
Args:
inputs (FloatTensor): ``(batch_size, T, model_dim)``
memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
tgt_pad_mask (LongTensor): ``(batch_size, 1, T)``
layer_cache (dict or None): cached layer info when stepwise decode
step (int or None): stepwise decoding counter
future (bool): If set True, do not apply future_mask.
Returns:
(FloatTensor, FloatTensor):
* output ``(batch_size, T, model_dim)``
* attns ``(batch_size, head, T, src_len)``
"""
dec_mask = None
if step is None:
tgt_len = tgt_pad_mask.size(-1)
if not future: # apply future_mask, result mask in (B, T, T)
future_mask = torch.ones(
[tgt_len, tgt_len],
device=tgt_pad_mask.device,
dtype=torch.uint8)
future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
# BoolTensor was introduced in pytorch 1.2
try:
future_mask = future_mask.bool()
except AttributeError:
pass
dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
else: # only mask padding, result mask in (B, 1, T)
dec_mask = tgt_pad_mask
input_norm = self.layer_norm_1(inputs)
if isinstance(self.self_attn, MultiHeadedAttention):
query, _ = self.self_attn(input_norm, input_norm, input_norm,
mask=dec_mask,
layer_cache=layer_cache,
attn_type="self")
elif isinstance(self.self_attn, AverageAttention):
query, _ = self.self_attn(input_norm, mask=dec_mask,
layer_cache=layer_cache, step=step)
query = self.drop(query) + inputs
query_norm = self.layer_norm_2(query)
mid, attns = self.context_attn(memory_bank, memory_bank, query_norm,
mask=src_pad_mask,
layer_cache=layer_cache,
attn_type="context")
output = self.feed_forward(self.drop(mid) + query)
return output, attns
def update_dropout(self, dropout, attention_dropout):
self.self_attn.update_dropout(attention_dropout)
@ -124,14 +202,25 @@ class TransformerDecoder(DecoderBase):
d_ff (int): size of the inner FF layer
copy_attn (bool): if using a separate copy attention
self_attn_type (str): type of self-attention scaled-dot, average
dropout (float): dropout in residual, self-attn(dot) and feed-forward
attention_dropout (float): dropout in context_attn (and self-attn(avg))
embeddings (onmt.modules.Embeddings):
embeddings to use, should have positional encodings
max_relative_positions (int):
Max distance between inputs in relative positions representations
aan_useffn (bool): Turn on the FFN layer in the AAN decoder
full_context_alignment (bool):
whether enable an extra full context decoder forward for alignment
alignment_layer (int): N° Layer to supervise with for alignment guiding
alignment_heads (int):
N. of cross attention heads to use for alignment guiding
""" """
def __init__(self, num_layers, d_model, heads, d_ff, def __init__(self, num_layers, d_model, heads, d_ff,
copy_attn, self_attn_type, dropout, attention_dropout, copy_attn, self_attn_type, dropout, attention_dropout,
embeddings, max_relative_positions, aan_useffn): embeddings, max_relative_positions, aan_useffn,
full_context_alignment, alignment_layer,
alignment_heads):
super(TransformerDecoder, self).__init__() super(TransformerDecoder, self).__init__()
self.embeddings = embeddings self.embeddings = embeddings
@ -143,7 +232,9 @@ class TransformerDecoder(DecoderBase):
[TransformerDecoderLayer(d_model, heads, d_ff, dropout,
attention_dropout, self_attn_type=self_attn_type,
max_relative_positions=max_relative_positions,
aan_useffn=aan_useffn,
full_context_alignment=full_context_alignment,
alignment_heads=alignment_heads)
for i in range(num_layers)])
# previously, there was a GlobalAttention module here for copy
@ -152,6 +243,8 @@ class TransformerDecoder(DecoderBase):
self._copy = copy_attn
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.alignment_layer = alignment_layer
@classmethod
def from_opt(cls, opt, embeddings):
"""Alternate constructor."""
@ -167,7 +260,10 @@ class TransformerDecoder(DecoderBase):
is list else opt.dropout,
embeddings,
opt.max_relative_positions,
opt.aan_useffn,
opt.full_context_alignment,
opt.alignment_layer,
alignment_heads=opt.alignment_heads)
def init_state(self, src, memory_bank, enc_hidden):
"""Initialize decoder state."""
@ -209,16 +305,22 @@ class TransformerDecoder(DecoderBase):
src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]
with_align = kwargs.pop('with_align', False)
attn_aligns = []
for i, layer in enumerate(self.transformer_layers):
layer_cache = self.state["cache"]["layer_{}".format(i)] \
if step is not None else None
output, attn, attn_align = layer(
output,
src_memory_bank,
src_pad_mask,
tgt_pad_mask,
layer_cache=layer_cache,
step=step,
with_align=with_align)
if attn_align is not None:
attn_aligns.append(attn_align)
output = self.layer_norm(output)
dec_outs = output.transpose(0, 1).contiguous()
@ -227,6 +329,9 @@ class TransformerDecoder(DecoderBase):
attns = {"std": attn} attns = {"std": attn}
if self._copy: if self._copy:
attns["copy"] = attn attns["copy"] = attn
if with_align:
attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
# attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg
# TODO change the way attns is returned dict => list or tuple (onnx) # TODO change the way attns is returned dict => list or tuple (onnx)
return dec_outs, attns return dec_outs, attns

Просмотреть файл

@ -1,6 +1,7 @@
"""Module defining encoders.""" """Module defining encoders."""
from onmt.encoders.encoder import EncoderBase from onmt.encoders.encoder import EncoderBase
from onmt.encoders.transformer import TransformerEncoder from onmt.encoders.transformer import TransformerEncoder
from onmt.encoders.ggnn_encoder import GGNNEncoder
from onmt.encoders.rnn_encoder import RNNEncoder from onmt.encoders.rnn_encoder import RNNEncoder
from onmt.encoders.cnn_encoder import CNNEncoder from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder from onmt.encoders.mean_encoder import MeanEncoder
@ -8,9 +9,9 @@ from onmt.encoders.audio_encoder import AudioEncoder
from onmt.encoders.image_encoder import ImageEncoder
str2enc = {"ggnn": GGNNEncoder, "rnn": RNNEncoder, "brnn": RNNEncoder,
"cnn": CNNEncoder, "transformer": TransformerEncoder,
"img": ImageEncoder, "audio": AudioEncoder, "mean": MeanEncoder}
__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder",
"MeanEncoder", "str2enc"]

Просмотреть файл

@ -0,0 +1,279 @@
"""Define GGNN-based encoders."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from onmt.encoders.encoder import EncoderBase
class GGNNAttrProxy(object):
"""
Translates index lookups into attribute lookups.
To implement some trick which able to use list of nn.Module in a nn.Module
see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2
"""
def __init__(self, module, prefix):
self.module = module
self.prefix = prefix
def __getitem__(self, i):
return getattr(self.module, self.prefix + str(i))
class GGNNPropogator(nn.Module):
"""
Gated Propogator for GGNN
Using LSTM gating mechanism
"""
def __init__(self, state_dim, n_node, n_edge_types):
super(GGNNPropogator, self).__init__()
self.n_node = n_node
self.n_edge_types = n_edge_types
self.reset_gate = nn.Sequential(
nn.Linear(state_dim*3, state_dim),
nn.Sigmoid()
)
self.update_gate = nn.Sequential(
nn.Linear(state_dim*3, state_dim),
nn.Sigmoid()
)
self.tansform = nn.Sequential(
nn.Linear(state_dim*3, state_dim),
nn.LeakyReLU()
)
def forward(self, state_in, state_out, state_cur, edges, nodes):
edges_in = edges[:, :, :nodes*self.n_edge_types]
edges_out = edges[:, :, nodes*self.n_edge_types:]
a_in = torch.bmm(edges_in, state_in)
a_out = torch.bmm(edges_out, state_out)
a = torch.cat((a_in, a_out, state_cur), 2)
r = self.reset_gate(a)
z = self.update_gate(a)
joined_input = torch.cat((a_in, a_out, r * state_cur), 2)
h_hat = self.tansform(joined_input)
output = (1 - z) * state_cur + z * h_hat
return output
class GGNNEncoder(EncoderBase):
""" A gated graph neural network configured as an encoder.
Based on github.com/JamesChuanggg/ggnn.pytorch.git,
which is based on the paper "Gated Graph Sequence Neural Networks"
by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.
Args:
rnn_type (str):
style of recurrent unit to use, one of [LSTM]
state_dim (int) : Number of state dimensions in nodes
n_edge_types (int) : Number of edge types
bidir_edges (bool): True if reverse edges should be autocreated
n_node (int) : Max nodes in graph
bridge_extra_node (bool): True indicates only 1st extra node
(after token listing) should be used for decoder init.
n_steps (int): Steps to advance graph encoder for stabilization
src_vocab (int): Path to source vocabulary
"""
def __init__(self, rnn_type, state_dim, bidir_edges,
n_edge_types, n_node, bridge_extra_node, n_steps, src_vocab):
super(GGNNEncoder, self).__init__()
self.state_dim = state_dim
self.n_edge_types = n_edge_types
self.n_node = n_node
self.n_steps = n_steps
self.bidir_edges = bidir_edges
self.bridge_extra_node = bridge_extra_node
for i in range(self.n_edge_types):
# incoming and outgoing edge embedding
in_fc = nn.Linear(self.state_dim, self.state_dim)
out_fc = nn.Linear(self.state_dim, self.state_dim)
self.add_module("in_{}".format(i), in_fc)
self.add_module("out_{}".format(i), out_fc)
self.in_fcs = GGNNAttrProxy(self, "in_")
self.out_fcs = GGNNAttrProxy(self, "out_")
# Find vocab data for tree builting
f = open(src_vocab, "r")
idx = 0
self.COMMA = -1
self.DELIMITER = -1
self.idx2num = []
for ln in f:
ln = ln.strip('\n')
if ln == ",":
self.COMMA = idx
if ln == "<EOT>":
self.DELIMITER = idx
if ln.isdigit():
self.idx2num.append(int(ln))
else:
self.idx2num.append(-1)
idx += 1
# Propogation Model
self.propogator = GGNNPropogator(self.state_dim, self.n_node,
self.n_edge_types)
self._initialization()
# Initialize the bridge layer
self._initialize_bridge(rnn_type, self.state_dim, 1)
@classmethod
def from_opt(cls, opt, embeddings):
"""Alternate constructor."""
return cls(
opt.rnn_type,
opt.state_dim,
opt.bidir_edges,
opt.n_edge_types,
opt.n_node,
opt.bridge_extra_node,
opt.n_steps,
opt.src_vocab)
def _initialization(self):
for m in self.modules():
if isinstance(m, nn.Linear):
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)
def forward(self, src, lengths=None):
"""See :func:`EncoderBase.forward()`"""
self._check_args(src, lengths)
nodes = self.n_node
batch_size = src.size()[1]
first_extra = np.zeros(batch_size, dtype=np.int32)
prop_state = np.zeros((batch_size, nodes, self.state_dim),
dtype=np.int32)
edges = np.zeros((batch_size, nodes, nodes*self.n_edge_types*2),
dtype=np.int32)
npsrc = src[:, :, 0].cpu().data.numpy().astype(np.int32)
# Initialize graph using formatted input sequence
for i in range(batch_size):
tokens_done = False
# Number of flagged nodes defines node count for this sample
# (Nodes can have no flags on them, but must be in 'flags' list).
flags = 0
flags_done = False
edge = 0
source_node = -1
for j in range(len(npsrc)):
token = npsrc[j][i]
if not tokens_done:
if token == self.DELIMITER:
tokens_done = True
first_extra[i] = j
else:
prop_state[i][j][token] = 1
elif token == self.DELIMITER:
flags += 1
flags_done = True
assert flags <= nodes
elif not flags_done:
# The total number of integers in the vocab should allow
# for all features and edges to be defined.
if token == self.COMMA:
flags = 0
else:
num = self.idx2num[token]
if num >= 0:
prop_state[i][flags][num+self.DELIMITER] = 1
flags += 1
elif token == self.COMMA:
edge += 1
assert source_node == -1, 'Error in graph edge input'
assert (edge <= 2*self.n_edge_types and
(not self.bidir_edges or edge < self.n_edge_types))
else:
num = self.idx2num[token]
if source_node < 0:
source_node = num
else:
edges[i][source_node][num+nodes*edge] = 1
if self.bidir_edges:
edges[i][num][nodes*(edge+self.n_edge_types)
+ source_node] = 1
source_node = -1
if torch.cuda.is_available():
prop_state = torch.from_numpy(prop_state).float().to("cuda:0")
edges = torch.from_numpy(edges).float().to("cuda:0")
else:
prop_state = torch.from_numpy(prop_state).float()
edges = torch.from_numpy(edges).float()
for i_step in range(self.n_steps):
in_states = []
out_states = []
for i in range(self.n_edge_types):
in_states.append(self.in_fcs[i](prop_state))
out_states.append(self.out_fcs[i](prop_state))
in_states = torch.stack(in_states).transpose(0, 1).contiguous()
in_states = in_states.view(-1, nodes*self.n_edge_types,
self.state_dim)
out_states = torch.stack(out_states).transpose(0, 1).contiguous()
out_states = out_states.view(-1, nodes*self.n_edge_types,
self.state_dim)
prop_state = self.propogator(in_states, out_states, prop_state,
edges, nodes)
prop_state = prop_state.transpose(0, 1)
if self.bridge_extra_node:
# Use first extra node as only source for decoder init
join_state = prop_state[first_extra, torch.arange(batch_size)]
else:
# Average all nodes to get bridge input
join_state = prop_state.mean(0)
join_state = torch.stack((join_state, join_state,
join_state, join_state))
join_state = (join_state, join_state)
encoder_final = self._bridge(join_state)
return encoder_final, prop_state, lengths
def _initialize_bridge(self, rnn_type,
hidden_size,
num_layers):
# LSTM has hidden and cell state, other only one
number_of_states = 2 if rnn_type == "LSTM" else 1
# Total number of states
self.total_hidden_dim = hidden_size * num_layers
# Build a linear layer for each
self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim,
self.total_hidden_dim,
bias=True)
for _ in range(number_of_states)])
def _bridge(self, hidden):
"""Forward hidden state through bridge."""
def bottle_hidden(linear, states):
"""
Transform from 3D to 2D, apply linear and return initial size
"""
size = states.size()
result = linear(states.view(-1, self.total_hidden_dim))
return F.leaky_relu(result).view(size)
if isinstance(hidden, tuple): # LSTM
outs = tuple([bottle_hidden(layer, hidden[ix])
for ix, layer in enumerate(self.bridge)])
else:
outs = bottle_hidden(self.bridge[0], hidden)
return outs

Просмотреть файл

@ -13,7 +13,6 @@ from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader
from onmt.inputters.vec_dataset import vec_sort_key, VecDataReader
from onmt.inputters.datareader_base import DataReaderBase
str2reader = {
"text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader,
"vec": VecDataReader}

Просмотреть файл

@ -108,7 +108,7 @@ class Dataset(TorchtextDataset):
""" """
def __init__(self, fields, readers, data, dirs, sort_key, def __init__(self, fields, readers, data, dirs, sort_key,
filter_pred=None): filter_pred=None, corpus_id=None):
self.sort_key = sort_key self.sort_key = sort_key
can_copy = 'src_map' in fields and 'alignment' in fields can_copy = 'src_map' in fields and 'alignment' in fields
@ -119,6 +119,10 @@ class Dataset(TorchtextDataset):
self.src_vocabs = []
examples = []
for ex_dict in starmap(_join_dicts, zip(*read_iters)):
if corpus_id is not None:
ex_dict["corpus_id"] = corpus_id
else:
ex_dict["corpus_id"] = "train"
if can_copy:
src_field = fields['src']
tgt_field = fields['tgt']
@ -152,3 +156,13 @@ class Dataset(TorchtextDataset):
if remove_fields:
self.fields = []
torch.save(self, path)
@staticmethod
def config(fields):
readers, data, dirs = [], [], []
for name, field in fields:
if field["data"] is not None:
readers.append(field["reader"])
data.append((name, field["data"]))
dirs.append(field["dir"])
return readers, data, dirs

Просмотреть файл

@ -9,7 +9,7 @@ from itertools import chain, cycle
import torch
import torchtext.data
from torchtext.data import Field, RawField, LabelField
from torchtext.vocab import Vocab
from torchtext.data.utils import RandomShuffler
@ -58,6 +58,47 @@ def make_tgt(data, vocab):
return alignment
class AlignField(LabelField):
"""
Parse ['<src>-<tgt>', ...] into ['<src>','<tgt>', ...]
"""
def __init__(self, **kwargs):
kwargs['use_vocab'] = False
kwargs['preprocessing'] = parse_align_idx
super(AlignField, self).__init__(**kwargs)
def process(self, batch, device=None):
""" Turn a batch of align-idx to a sparse align idx Tensor"""
sparse_idx = []
for i, example in enumerate(batch):
for src, tgt in example:
# +1 for tgt side to keep coherent after "bos" padding,
# register ['N°_in_batch', 'tgt_id+1', 'src_id']
sparse_idx.append([i, tgt + 1, src])
align_idx = torch.tensor(sparse_idx, dtype=self.dtype, device=device)
return align_idx
def parse_align_idx(align_pharaoh):
"""
Parse Pharaoh alignment into [[<src>, <tgt>], ...]
"""
align_list = align_pharaoh.strip().split(' ')
flatten_align_idx = []
for align in align_list:
try:
src_idx, tgt_idx = align.split('-')
except ValueError:
logger.warning("{} in `{}`".format(align, align_pharaoh))
logger.warning("Bad alignement line exists. Please check file!")
raise
flatten_align_idx.append([int(src_idx), int(tgt_idx)])
return flatten_align_idx
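
A small worked example (illustrative only) of the Pharaoh parsing above:

```python
parse_align_idx("0-0 1-2 2-1")
# -> [[0, 0], [1, 2], [2, 1]]
# AlignField.process then shifts each tgt index by +1 (to stay consistent with
# the <bos> padding) and prefixes the example's position in the batch.
```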
def get_fields( def get_fields(
src_data_type, src_data_type,
n_src_feats, n_src_feats,
@ -66,6 +107,7 @@ def get_fields(
bos='<s>', bos='<s>',
eos='</s>', eos='</s>',
dynamic_dict=False, dynamic_dict=False,
with_align=False,
src_truncate=None, src_truncate=None,
tgt_truncate=None tgt_truncate=None
): ):
@ -84,6 +126,7 @@ def get_fields(
for tgt. for tgt.
dynamic_dict (bool): Whether or not to include source map and dynamic_dict (bool): Whether or not to include source map and
alignment fields. alignment fields.
with_align (bool): Whether or not to include word align.
src_truncate: Cut off src sequences beyond this (passed to src_truncate: Cut off src sequences beyond this (passed to
``src_data_type``'s data reader - see there for more details). ``src_data_type``'s data reader - see there for more details).
tgt_truncate: Cut off tgt sequences beyond this (passed to tgt_truncate: Cut off tgt sequences beyond this (passed to
@ -122,6 +165,9 @@ def get_fields(
indices = Field(use_vocab=False, dtype=torch.long, sequential=False) indices = Field(use_vocab=False, dtype=torch.long, sequential=False)
fields["indices"] = indices fields["indices"] = indices
corpus_ids = Field(use_vocab=True, sequential=False)
fields["corpus_id"] = corpus_ids
if dynamic_dict: if dynamic_dict:
src_map = Field( src_map = Field(
use_vocab=False, dtype=torch.float, use_vocab=False, dtype=torch.float,
@ -136,9 +182,20 @@ def get_fields(
postprocessing=make_tgt, sequential=False) postprocessing=make_tgt, sequential=False)
fields["alignment"] = align fields["alignment"] = align
if with_align:
word_align = AlignField()
fields["align"] = word_align
return fields return fields
def patch_fields(opt, fields):
dvocab = torch.load(opt.data + '.vocab.pt')
maybe_cid_field = dvocab.get('corpus_id', None)
if maybe_cid_field is not None:
fields.update({'corpus_id': maybe_cid_field})
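
A hedged usage sketch of `patch_fields`; the surrounding checkpoint handling is assumed rather than shown in this commit.

```python
vocab = torch.load(opt.data + '.vocab.pt')             # the same file patch_fields reads
fields = load_old_vocab(vocab, data_type=opt.data_type)
patch_fields(opt, fields)                              # restores fields['corpus_id'] if present
```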
def load_old_vocab(vocab, data_type="text", dynamic_dict=False): def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
"""Update a legacy vocab/field format. """Update a legacy vocab/field format.
@ -317,7 +374,9 @@ def _build_fv_from_multifield(multifield, counters, build_fv_args,
def _build_fields_vocab(fields, counters, data_type, share_vocab, def _build_fields_vocab(fields, counters, data_type, share_vocab,
vocab_size_multiple, vocab_size_multiple,
src_vocab_size, src_words_min_frequency, src_vocab_size, src_words_min_frequency,
tgt_vocab_size, tgt_words_min_frequency): tgt_vocab_size, tgt_words_min_frequency,
subword_prefix="",
subword_prefix_is_joiner=False):
build_fv_args = defaultdict(dict) build_fv_args = defaultdict(dict)
build_fv_args["src"] = dict( build_fv_args["src"] = dict(
max_size=src_vocab_size, min_freq=src_words_min_frequency) max_size=src_vocab_size, min_freq=src_words_min_frequency)
@ -329,6 +388,11 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab,
counters, counters,
build_fv_args, build_fv_args,
size_multiple=vocab_size_multiple if not share_vocab else 1) size_multiple=vocab_size_multiple if not share_vocab else 1)
if fields.get("corpus_id", False):
fields["corpus_id"].vocab = fields["corpus_id"].vocab_cls(
counters["corpus_id"])
if data_type == 'text': if data_type == 'text':
src_multifield = fields["src"] src_multifield = fields["src"]
_build_fv_from_multifield( _build_fv_from_multifield(
@ -336,6 +400,7 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab,
counters, counters,
build_fv_args, build_fv_args,
size_multiple=vocab_size_multiple if not share_vocab else 1) size_multiple=vocab_size_multiple if not share_vocab else 1)
if share_vocab: if share_vocab:
# `tgt_vocab_size` is ignored when sharing vocabularies # `tgt_vocab_size` is ignored when sharing vocabularies
logger.info(" * merging src and tgt vocab...") logger.info(" * merging src and tgt vocab...")
@ -347,9 +412,38 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab,
vocab_size_multiple=vocab_size_multiple) vocab_size_multiple=vocab_size_multiple)
logger.info(" * merged vocab size: %d." % len(src_field.vocab)) logger.info(" * merged vocab size: %d." % len(src_field.vocab))
build_noise_field(
src_multifield.base_field,
subword_prefix=subword_prefix,
is_joiner=subword_prefix_is_joiner)
return fields return fields
def build_noise_field(src_field, subword=True,
subword_prefix="", is_joiner=False,
sentence_breaks=[".", "?", "!"]):
"""In place add noise related fields i.e.:
- word_start
- end_of_sentence
"""
if subword:
def is_word_start(x): return (x.startswith(subword_prefix) ^ is_joiner)
sentence_breaks = [subword_prefix + t for t in sentence_breaks]
else:
def is_word_start(x): return True
vocab_size = len(src_field.vocab)
word_start_mask = torch.zeros([vocab_size]).bool()
end_of_sentence_mask = torch.zeros([vocab_size]).bool()
for i, t in enumerate(src_field.vocab.itos):
if is_word_start(t):
word_start_mask[i] = True
if t in sentence_breaks:
end_of_sentence_mask[i] = True
src_field.word_start_mask = word_start_mask
src_field.end_of_sentence_mask = end_of_sentence_mask
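
A toy illustration of the masks that `build_noise_field` attaches, assuming a SentencePiece-style `▁` word-start prefix (the vocabulary is invented):

```python
import torch

itos = ["<unk>", "<pad>", "▁the", "cat", "▁sat", "▁."]
subword_prefix, is_joiner = "▁", False
sentence_breaks = [subword_prefix + t for t in [".", "?", "!"]]

word_start_mask = torch.tensor(
    [t.startswith(subword_prefix) ^ is_joiner for t in itos])   # F F T F T T
end_of_sentence_mask = torch.tensor(
    [t in sentence_breaks for t in itos])                       # F F F F F T
```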
def build_vocab(train_dataset_files, fields, data_type, share_vocab, def build_vocab(train_dataset_files, fields, data_type, share_vocab,
src_vocab_path, src_vocab_size, src_words_min_frequency, src_vocab_path, src_vocab_size, src_words_min_frequency,
tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency, tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency,
@ -776,11 +870,14 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
to iterate over. We implement simple ordered iterator strategy here, to iterate over. We implement simple ordered iterator strategy here,
but more sophisticated strategy like curriculum learning is ok too. but more sophisticated strategy like curriculum learning is ok too.
""" """
dataset_glob = opt.data + '.' + corpus_type + '.[0-9]*.pt'
dataset_paths = list(sorted( dataset_paths = list(sorted(
glob.glob(opt.data + '.' + corpus_type + '.[0-9]*.pt'))) glob.glob(dataset_glob),
key=lambda p: int(p.split(".")[-2])))
if not dataset_paths: if not dataset_paths:
if is_train: if is_train:
raise ValueError('Training data %s not found' % opt.data) raise ValueError('Training data %s not found' % dataset_glob)
else: else:
return None return None
if multi: if multi:
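
The new numeric sort key matters once a corpus has ten or more shards: plain lexicographic sorting of the glob results would interleave them. Illustrative paths:

```python
paths = ["demo.train.0.pt", "demo.train.2.pt", "demo.train.10.pt"]
sorted(paths)
# ['demo.train.0.pt', 'demo.train.10.pt', 'demo.train.2.pt']    (lexicographic)
sorted(paths, key=lambda p: int(p.split(".")[-2]))
# ['demo.train.0.pt', 'demo.train.2.pt', 'demo.train.10.pt']    (shard order)
```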


@ -190,6 +190,8 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
vocab_size = len(tgt_base_field.vocab) vocab_size = len(tgt_base_field.vocab)
pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx) generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)
if model_opt.share_decoder_embeddings:
generator.linear.weight = decoder.embeddings.word_lut.weight
# Load the model states from checkpoint or initialize them. # Load the model states from checkpoint or initialize them.
if checkpoint is not None: if checkpoint is not None:
@ -230,7 +232,8 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
model.generator = generator model.generator = generator
model.to(device) model.to(device)
if model_opt.model_dtype == 'fp16' and model_opt.optim == 'fusedadam':
model.half()
return model return model
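
An illustrative check (not part of the commit) of what the new `share_decoder_embeddings` branch does when copy attention is on: the copy generator's output projection and the decoder's word embedding end up as the same parameter.

```python
# assumes `model` was built with -copy_attn and -share_decoder_embeddings
w_gen = model.generator.linear.weight
w_emb = model.decoder.embeddings.word_lut.weight
assert w_gen is w_emb      # tied: one Parameter object shared by both modules
```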


@ -2,5 +2,4 @@
from onmt.models.model_saver import build_model_saver, ModelSaver from onmt.models.model_saver import build_model_saver, ModelSaver
from onmt.models.model import NMTModel from onmt.models.model import NMTModel
__all__ = ["build_model_saver", "ModelSaver", __all__ = ["build_model_saver", "ModelSaver", "NMTModel"]
"NMTModel", "check_sru_requirement"]


@ -17,7 +17,7 @@ class NMTModel(nn.Module):
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
def forward(self, src, tgt, lengths, bptt=False): def forward(self, src, tgt, lengths, bptt=False, with_align=False):
"""Forward propagate a `src` and `tgt` pair for training. """Forward propagate a `src` and `tgt` pair for training.
Possible initialized with a beginning decoder state. Possible initialized with a beginning decoder state.
@ -26,10 +26,13 @@ class NMTModel(nn.Module):
typically for inputs this will be a padded `LongTensor` typically for inputs this will be a padded `LongTensor`
of size ``(len, batch, features)``. However, may be an of size ``(len, batch, features)``. However, may be an
image or other generic input depending on encoder. image or other generic input depending on encoder.
tgt (LongTensor): A target sequence of size ``(tgt_len, batch)``. tgt (LongTensor): A target sequence passed to decoder.
Size ``(tgt_len, batch, features)``.
lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. lengths(LongTensor): The src lengths, pre-padding ``(batch,)``.
bptt (Boolean): A flag indicating if truncated bptt is set. bptt (Boolean): A flag indicating if truncated bptt is set.
If reset then init_state If reset then init_state
with_align (Boolean): A flag indicating whether to output alignment
attention. Only valid for the transformer decoder.
Returns: Returns:
(FloatTensor, dict[str, FloatTensor]): (FloatTensor, dict[str, FloatTensor]):
@ -37,14 +40,15 @@ class NMTModel(nn.Module):
* decoder output ``(tgt_len, batch, hidden)`` * decoder output ``(tgt_len, batch, hidden)``
* dictionary attention dists of ``(tgt_len, batch, src_len)`` * dictionary attention dists of ``(tgt_len, batch, src_len)``
""" """
tgt = tgt[:-1] # exclude last target from inputs dec_in = tgt[:-1] # exclude last target from inputs
enc_state, memory_bank, lengths = self.encoder(src, lengths) enc_state, memory_bank, lengths = self.encoder(src, lengths)
if bptt is False: if bptt is False:
self.decoder.init_state(src, memory_bank, enc_state) self.decoder.init_state(src, memory_bank, enc_state)
dec_out, attns = self.decoder(tgt, memory_bank, dec_out, attns = self.decoder(dec_in, memory_bank,
memory_lengths=lengths) memory_lengths=lengths,
with_align=with_align)
return dec_out, attns return dec_out, attns
def update_dropout(self, dropout): def update_dropout(self, dropout):
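
A hedged usage sketch of the new `with_align` flag; the attention-dictionary keys shown are assumptions about the decoder's return convention, not something this hunk defines.

```python
dec_out, attns = model(src, tgt, lengths, with_align=True)
std_attn = attns["std"]            # (tgt_len, batch, src_len)
align_attn = attns.get("align")    # alignment head(s); transformer decoder only
```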


@ -1,6 +1,5 @@
import os import os
import torch import torch
import torch.nn as nn
from collections import deque from collections import deque
from onmt.utils.logging import logger from onmt.utils.logging import logger
@ -48,18 +47,20 @@ class ModelSaverBase(object):
if self.keep_checkpoint == 0 or step == self.last_saved_step: if self.keep_checkpoint == 0 or step == self.last_saved_step:
return return
if moving_average:
save_model = deepcopy(self.model)
for avg, param in zip(moving_average, save_model.parameters()):
param.data.copy_(avg.data)
else:
save_model = self.model save_model = self.model
if moving_average:
model_params_data = []
for avg, param in zip(moving_average, save_model.parameters()):
model_params_data.append(param.data)
param.data = avg.data
chkpt, chkpt_name = self._save(step, save_model) chkpt, chkpt_name = self._save(step, save_model)
self.last_saved_step = step self.last_saved_step = step
if moving_average: if moving_average:
del save_model for param_data, param in zip(model_params_data,
save_model.parameters()):
param.data = param_data
if self.keep_checkpoint > 0: if self.keep_checkpoint > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen: if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
@ -97,17 +98,10 @@ class ModelSaver(ModelSaverBase):
"""Simple model saver to filesystem""" """Simple model saver to filesystem"""
def _save(self, step, model): def _save(self, step, model):
real_model = (model.module model_state_dict = model.state_dict()
if isinstance(model, nn.DataParallel)
else model)
real_generator = (real_model.generator.module
if isinstance(real_model.generator, nn.DataParallel)
else real_model.generator)
model_state_dict = real_model.state_dict()
model_state_dict = {k: v for k, v in model_state_dict.items() model_state_dict = {k: v for k, v in model_state_dict.items()
if 'generator' not in k} if 'generator' not in k}
generator_state_dict = real_generator.state_dict() generator_state_dict = model.generator.state_dict()
# NOTE: We need to trim the vocab to remove any unk tokens that # NOTE: We need to trim the vocab to remove any unk tokens that
# were not originally here. # were not originally here.
@ -137,4 +131,5 @@ class ModelSaver(ModelSaverBase):
return checkpoint, checkpoint_path return checkpoint, checkpoint_path
def _rm_checkpoint(self, name): def _rm_checkpoint(self, name):
if os.path.exists(name):
os.remove(name) os.remove(name)


@ -11,6 +11,8 @@ from onmt.modules.embeddings import Embeddings, PositionalEncoding, \
from onmt.modules.weight_norm import WeightNormConv2d from onmt.modules.weight_norm import WeightNormConv2d
from onmt.modules.average_attn import AverageAttention from onmt.modules.average_attn import AverageAttention
import onmt.modules.source_noise # noqa
__all__ = ["Elementwise", "context_gate_factory", "ContextGate", __all__ = ["Elementwise", "context_gate_factory", "ContextGate",
"GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator",
"CopyGeneratorLoss", "CopyGeneratorLossCompute", "CopyGeneratorLoss", "CopyGeneratorLossCompute",


@ -180,7 +180,7 @@ class GlobalAttention(nn.Module):
if memory_lengths is not None: if memory_lengths is not None:
mask = sequence_mask(memory_lengths, max_len=align.size(-1)) mask = sequence_mask(memory_lengths, max_len=align.size(-1))
mask = mask.unsqueeze(1) # Make it broadcastable. mask = mask.unsqueeze(1) # Make it broadcastable.
align.masked_fill_(1 - mask, -float('inf')) align.masked_fill_(~mask, -float('inf'))
# Softmax or sparsemax to normalize attention weights # Softmax or sparsemax to normalize attention weights
if self.attn_func == "softmax": if self.attn_func == "softmax":


@ -92,7 +92,7 @@ class MultiHeadedAttention(nn.Module):
(FloatTensor, FloatTensor): (FloatTensor, FloatTensor):
* output context vectors ``(batch, query_len, dim)`` * output context vectors ``(batch, query_len, dim)``
* one of the attention vectors ``(batch, query_len, key_len)`` * Attention vector in heads ``(batch, head, query_len, key_len)``.
""" """
# CHECKS # CHECKS
@ -219,13 +219,12 @@ class MultiHeadedAttention(nn.Module):
# aeq(batch, batch_) # aeq(batch, batch_)
# aeq(d, d_) # aeq(d, d_)
# Return one attn # Return multi-head attn
top_attn = attn \ attns = attn \
.view(batch_size, head_count, .view(batch_size, head_count,
query_len, key_len)[:, 0, :, :] \ query_len, key_len)
.contiguous()
return output, top_attn return output, attns
def update_dropout(self, dropout): def update_dropout(self, dropout):
self.dropout.p = dropout self.dropout.p = dropout
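
Because the module now returns every head, callers that relied on the old single-map output can slice it back out; a sketch with hypothetical tensor names:

```python
# attns: (batch, head, query_len, key_len), as returned above
top_attn = attns[:, 0, :, :].contiguous()   # the previous "one attn" view (head 0)
avg_attn = attns.mean(dim=1)                # or average over heads instead
```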


@ -0,0 +1,350 @@
import math
import torch
def aeq(ref, *args):
for i, e in enumerate(args):
assert ref == e, "%s != %s (element %d)" % (str(ref), str(e), i)
class NoiseBase(object):
def __init__(self, prob, pad_idx=1, device_id="cpu",
ids_to_noise=[], **kwargs):
self.prob = prob
self.pad_idx = 1
self.skip_first = 1
self.device_id = device_id
self.ids_to_noise = set([t.item() for t in ids_to_noise])
def __call__(self, batch):
return self.noise_batch(batch)
def to_device(self, t):
return t.to(torch.device(self.device_id))
def noise_batch(self, batch):
source, lengths = batch.src if isinstance(batch.src, tuple) \
else (batch.src, [None] * batch.src.size(1))
# noise_skip = batch.noise_skip
# aeq(len(batch.noise_skip) == source.size(1))
# source is [src_len x bs x feats]
skipped = source[:self.skip_first, :, :]
source = source[self.skip_first:]
for i in range(source.size(1)):
if hasattr(batch, 'corpus_id'):
corpus_id = batch.corpus_id[i]
if corpus_id.item() not in self.ids_to_noise:
continue
tokens = source[:, i, 0]
mask = tokens.ne(self.pad_idx)
masked_tokens = tokens[mask]
noisy_tokens, length = self.noise_source(
masked_tokens, length=lengths[i])
lengths[i] = length
# source might increase length so we need to resize the whole
# tensor
delta = length - (source.size(0) - self.skip_first)
if delta > 0:
pad = torch.ones([delta],
device=source.device,
dtype=source.dtype)
pad *= self.pad_idx
pad = pad.unsqueeze(1).expand(-1, 15).unsqueeze(2)
source = torch.cat([source, source])
source[:noisy_tokens.size(0), i, 0] = noisy_tokens
source = torch.cat([skipped, source])
# remove useless pad
max_len = lengths.max()
source = source[:max_len, :, :]
batch.src = source, lengths
return batch
def noise_source(self, source, **kwargs):
raise NotImplementedError()
class MaskNoise(NoiseBase):
def noise_batch(self, batch):
raise ValueError("MaskNoise has not been updated to tensor noise")
# def s(self, tokens):
# prob = self.prob
# r = torch.rand([len(tokens)])
# mask = False
# masked = []
# for i, tok in enumerate(tokens):
# if tok.startswith(subword_prefix):
# if r[i].item() <= prob:
# masked.append(mask_tok)
# mask = True
# else:
# masked.append(tok)
# mask = False
# else:
# if mask:
# pass
# else:
# masked.append(tok)
# return masked
class SenShufflingNoise(NoiseBase):
def __init__(self, *args, end_of_sentence_mask=None, **kwargs):
super(SenShufflingNoise, self).__init__(*args, **kwargs)
assert end_of_sentence_mask is not None
self.end_of_sentence_mask = self.to_device(end_of_sentence_mask)
def is_end_of_sentence(self, source):
return self.end_of_sentence_mask.gather(0, source)
def noise_source(self, source, length=None, **kwargs):
# aeq(source.size(0), length)
full_stops = self.is_end_of_sentence(source)
# Pretend it ends with a full stop so last span is a sentence
full_stops[-1] = 1
# Tokens that are full stops, where the previous token is not
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero() + 2
result = source.clone()
num_sentences = sentence_ends.size(0)
num_to_permute = math.ceil((num_sentences * 2 * self.prob) / 2.0)
substitutions = torch.randperm(num_sentences)[:num_to_permute]
ordering = torch.arange(0, num_sentences)
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
index = 0
for i in ordering:
sentence = source[(sentence_ends[i - 1] if i >
0 else 1):sentence_ends[i]]
result[index:index + sentence.size(0)] = sentence
index += sentence.size(0)
# aeq(source.size(0), length)
return result, length
class InfillingNoise(NoiseBase):
def __init__(self, *args, infilling_poisson_lambda=3.0,
word_start_mask=None, **kwargs):
super(InfillingNoise, self).__init__(*args, **kwargs)
self.poisson_lambda = infilling_poisson_lambda
self.mask_span_distribution = self._make_poisson(self.poisson_lambda)
self.mask_idx = 0
assert word_start_mask is not None
self.word_start_mask = self.to_device(word_start_mask)
# -1: keep everything (i.e. 1 mask per token)
# 0: replace everything (i.e. no mask)
# 1: 1 mask per span
self.replace_length = 1
def _make_poisson(self, poisson_lambda):
# fairseq/data/denoising_dataset.py
_lambda = poisson_lambda
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-_lambda)
k_factorial = 1
ps = []
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= _lambda
k_factorial *= (k + 1)
if ps[-1] < 0.0000001:
break
ps = torch.tensor(ps, device=torch.device(self.device_id))
return torch.distributions.Categorical(ps)
def is_word_start(self, source):
# print("src size: ", source.size())
# print("ws size: ", self.word_start_mask.size())
# print("max: ", source.max())
# assert source.max() < self.word_start_mask.size(0)
# assert source.min() >= 0
return self.word_start_mask.gather(0, source)
def noise_source(self, source, **kwargs):
is_word_start = self.is_word_start(source)
# assert source.size() == is_word_start.size()
# aeq(source.eq(self.pad_idx).long().sum(), 0)
# we manually add this hypothesis since it's required for the rest
# of the function and kindof make sense
is_word_start[-1] = 0
p = self.prob
num_to_mask = (is_word_start.float().sum() * p).ceil().long()
num_inserts = 0
if num_to_mask == 0:
return source
if self.mask_span_distribution is not None:
lengths = self.mask_span_distribution.sample(
sample_shape=(num_to_mask,))
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < num_to_mask:
lengths = torch.cat([
lengths,
self.mask_span_distribution.sample(
sample_shape=(num_to_mask,))
], dim=0)
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
i = 0
while cum_length[i] < num_to_mask:
i += 1
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
num_to_mask = i + 1
lengths = lengths[:num_to_mask]
# Handle 0-length mask (inserts) separately
lengths = lengths[lengths > 0]
num_inserts = num_to_mask - lengths.size(0)
num_to_mask -= num_inserts
if num_to_mask == 0:
return self.add_insertion_noise(
source, num_inserts / source.size(0))
# assert (lengths > 0).all()
else:
raise ValueError("Not supposed to be there")
lengths = torch.ones((num_to_mask,), device=source.device).long()
# assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero()
indices = word_starts[torch.randperm(word_starts.size(0))[
:num_to_mask]].squeeze(1)
source_length = source.size(0)
# TODO why?
# assert source_length - 1 not in indices
to_keep = torch.ones(
source_length,
dtype=torch.bool,
device=source.device)
is_word_start = is_word_start.long()
# acts as a long length, so spans don't go over the end of doc
is_word_start[-1] = 10e5
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
# random ratio disabled
# source[indices[mask_random]] = torch.randint(
# 1, len(self.vocab), size=(mask_random.sum(),))
# if self.mask_span_distribution is not None:
# assert len(lengths.size()) == 1
# assert lengths.size() == indices.size()
lengths -= 1
while indices.size(0) > 0:
# assert lengths.size() == indices.size()
lengths -= is_word_start[indices + 1].long()
uncompleted = lengths >= 0
indices = indices[uncompleted] + 1
# mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
# random ratio disabled
# source[indices[mask_random]] = torch.randint(
# 1, len(self.vocab), size=(mask_random.sum(),))
# else:
# # A bit faster when all lengths are 1
# while indices.size(0) > 0:
# uncompleted = is_word_start[indices + 1] == 0
# indices = indices[uncompleted] + 1
# mask_random = mask_random[uncompleted]
# if self.replace_length != -1:
# # delete token
# to_keep[indices] = 0
# else:
# # keep index, but replace it with [MASK]
# source[indices] = self.mask_idx
# source[indices[mask_random]] = torch.randint(
# 1, len(self.vocab), size=(mask_random.sum(),))
# assert source_length - 1 not in indices
source = source[to_keep]
if num_inserts > 0:
source = self.add_insertion_noise(
source, num_inserts / source.size(0))
# aeq(source.eq(self.pad_idx).long().sum(), 0)
final_length = source.size(0)
return source, final_length
def add_insertion_noise(self, tokens, p):
if p == 0.0:
return tokens
num_tokens = tokens.size(0)
n = int(math.ceil(num_tokens * p))
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
noise_mask = torch.zeros(
size=(
num_tokens + n,
),
dtype=torch.bool,
device=tokens.device)
noise_mask[noise_indices] = 1
result = torch.ones([n + len(tokens)],
dtype=torch.long,
device=tokens.device) * -1
# random ratio disabled
# num_random = int(math.ceil(n * self.random_ratio))
result[noise_indices] = self.mask_idx
# result[noise_indices[:num_random]] = torch.randint(
# low=1, high=len(self.vocab), size=(num_random,))
result[~noise_mask] = tokens
# assert (result >= 0).all()
return result
class MultiNoise(NoiseBase):
NOISES = {
"sen_shuffling": SenShufflingNoise,
"infilling": InfillingNoise,
"mask": MaskNoise
}
def __init__(self, noises=[], probs=[], **kwargs):
assert len(noises) == len(probs)
super(MultiNoise, self).__init__(probs, **kwargs)
self.noises = []
for i, n in enumerate(noises):
cls = MultiNoise.NOISES.get(n)
if cls is None:
raise ValueError("Unknown noise function '%s'" % n)
else:
noise = cls(probs[i], **kwargs)
self.noises.append(noise)
def noise_source(self, source, length=None, **kwargs):
for noise in self.noises:
source, length = noise.noise_source(
source, length=length, **kwargs)
return source, length
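
A hedged construction sketch for the noise classes above, wiring in the masks produced by `build_noise_field`; the probabilities, corpus ids and mask sizes are placeholders.

```python
import torch

vocab_size = 100
word_start_mask = torch.zeros(vocab_size, dtype=torch.bool)        # placeholders for the
end_of_sentence_mask = torch.zeros(vocab_size, dtype=torch.bool)   # build_noise_field masks

noise = MultiNoise(
    noises=["sen_shuffling", "infilling"],
    probs=[0.1, 0.3],
    pad_idx=1,
    device_id="cpu",
    ids_to_noise=[torch.tensor(2)],            # corpus ids selected for noising
    word_start_mask=word_start_mask,
    end_of_sentence_mask=end_of_sentence_mask,
)
# applied to a torchtext batch (batch.src == (tensor, lengths)):
# noisy_batch = noise(batch)
```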


@ -2,6 +2,8 @@
from __future__ import print_function from __future__ import print_function
import configargparse import configargparse
import onmt
from onmt.models.sru import CheckSRU from onmt.models.sru import CheckSRU
@ -12,6 +14,149 @@ def config_opts(parser):
is_write_out_config_file_arg=True, is_write_out_config_file_arg=True,
help='config file save path') help='config file save path')
def global_opts(parser):
group = parser.add_argument_group('Model')
group.add('--fp32', '-fp32', action='store_true',
help="Force the model to be in FP32 "
"because FP16 is very slow on GTX1080(ti).")
group.add('--avg_raw_probs', '-avg_raw_probs', action='store_true',
help="If this is set, during ensembling scores from "
"different models will be combined by averaging their "
"raw probabilities and then taking the log. Otherwise, "
"the log probabilities will be averaged directly. "
"Necessary for models whose output layers can assign "
"zero probability.")
group = parser.add_argument_group('Data')
group.add('--data_type', '-data_type', default="text",
help="Type of the source input. Options: [text|img].")
group.add('--shard_size', '-shard_size', type=int, default=10000,
help="Divide src and tgt (if applicable) into "
"smaller multiple src and tgt files, then "
"build shards, each shard will have "
"opt.shard_size samples except last shard. "
"shard_size=0 means no segmentation "
"shard_size>0 means segment dataset into multiple shards, "
"each shard has shard_size samples")
group.add('--output', '-output', default='pred.txt',
help="Path to output the predictions (each line will "
"be the decoded sequence")
group.add('--report_align', '-report_align', action='store_true',
help="Report alignment for each translation.")
group.add('--report_time', '-report_time', action='store_true',
help="Report some translation time metrics")
group = parser.add_argument_group('Random Sampling')
group.add('--random_sampling_topk', '-random_sampling_topk',
default=1, type=int,
help="Set this to -1 to do random sampling from full "
"distribution. Set this to value k>1 to do random "
"sampling restricted to the k most likely next tokens. "
"Set this to 1 to use argmax or for doing beam "
"search.")
group.add('--random_sampling_temp', '-random_sampling_temp',
default=1., type=float,
help="If doing random sampling, divide the logits by "
"this before computing softmax during decoding.")
group.add('--seed', '-seed', type=int, default=829,
help="Random seed")
group = parser.add_argument_group('Beam')
group.add('--beam_size', '-beam_size', type=int, default=5,
help='Beam size')
group.add('--min_length', '-min_length', type=int, default=0,
help='Minimum prediction length')
group.add('--max_length', '-max_length', type=int, default=100,
help='Maximum prediction length.')
group.add('--max_sent_length', '-max_sent_length', action=DeprecateAction,
help="Deprecated, use `-max_length` instead")
# Alpha and Beta values for Google Length + Coverage penalty
# Described here: https://arxiv.org/pdf/1609.08144.pdf, Section 7
group.add('--stepwise_penalty', '-stepwise_penalty', action='store_true',
help="Apply penalty at every decoding step. "
"Helpful for summary penalty.")
group.add('--length_penalty', '-length_penalty', default='none',
choices=['none', 'wu', 'avg'],
help="Length Penalty to use.")
group.add('--ratio', '-ratio', type=float, default=-0.,
help="Ratio based beam stop condition")
group.add('--coverage_penalty', '-coverage_penalty', default='none',
choices=['none', 'wu', 'summary'],
help="Coverage Penalty to use.")
group.add('--alpha', '-alpha', type=float, default=0.,
help="Google NMT length penalty parameter "
"(higher = longer generation)")
group.add('--beta', '-beta', type=float, default=-0.,
help="Coverage penalty parameter")
group.add('--block_ngram_repeat', '-block_ngram_repeat',
type=int, default=0,
help='Block repetition of ngrams during decoding.')
group.add('--ignore_when_blocking', '-ignore_when_blocking',
nargs='+', type=str, default=[],
help="Ignore these strings when blocking repeats. "
"You want to block sentence delimiters.")
group.add('--replace_unk', '-replace_unk', action="store_true",
help="Replace the generated UNK tokens with the "
"source token that had highest attention weight. If "
"phrase_table is provided, it will look up the "
"identified source token and give the corresponding "
"target token. If it is not provided (or the identified "
"source token does not exist in the table), then it "
"will copy the source token.")
group.add('--phrase_table', '-phrase_table', type=str, default="",
help="If phrase_table is provided (with replace_unk), it will "
"look up the identified source token and give the "
"corresponding target token. If it is not provided "
"(or the identified source token does not exist in "
"the table), then it will copy the source token.")
group = parser.add_argument_group('Logging')
group.add('--verbose', '-verbose', action="store_true",
help='Print scores and predictions for each sentence')
group.add('--log_file', '-log_file', type=str, default="",
help="Output logs to a file under this path.")
group.add('--log_file_level', '-log_file_level', type=str,
action=StoreLoggingLevelAction,
choices=StoreLoggingLevelAction.CHOICES,
default="0")
group.add('--attn_debug', '-attn_debug', action="store_true",
help='Print best attn for each word')
group.add('--align_debug', '-align_debug', action="store_true",
help='Print best align for each word')
group.add('--dump_beam', '-dump_beam', type=str, default="",
help='File to dump beam information to.')
group.add('--n_best', '-n_best', type=int, default=1,
help="If verbose is set, will output the n_best "
"decoded sentences")
group = parser.add_argument_group('Efficiency')
group.add('--batch_size', '-batch_size', type=int, default=30,
help='Batch size')
group.add('--batch_type', '-batch_type', default='sents',
choices=["sents", "tokens"],
help="Batch grouping for batch_size. Standard "
"is sents. Tokens will do dynamic batching")
group.add('--gpu', '-gpu', type=int, default=-1,
help="Device to run on")
# Options most relevant to speech.
group = parser.add_argument_group('Speech')
group.add('--sample_rate', '-sample_rate', type=int, default=16000,
help="Sample rate.")
group.add('--window_size', '-window_size', type=float, default=.02,
help='Window size for spectrogram in seconds')
group.add('--window_stride', '-window_stride', type=float, default=.01,
help='Window stride for spectrogram in seconds')
group.add('--window', '-window', default='hamming',
help='Window type for spectrogram generation')
# Option most relevant to image input
group.add('--image_channel_size', '-image_channel_size',
type=int, default=3, choices=[3, 1],
help="Using grayscale image can training "
"model faster and smaller")
def model_opts(parser): def model_opts(parser):
""" """
@ -69,10 +214,10 @@ def model_opts(parser):
help='Data type of the model.') help='Data type of the model.')
group.add('--encoder_type', '-encoder_type', type=str, default='rnn', group.add('--encoder_type', '-encoder_type', type=str, default='rnn',
choices=['rnn', 'brnn', 'mean', 'transformer', 'cnn'], choices=['rnn', 'brnn', 'ggnn', 'mean', 'transformer', 'cnn'],
help="Type of encoder layer to use. Non-RNN layers " help="Type of encoder layer to use. Non-RNN layers "
"are experimental. Options are " "are experimental. Options are "
"[rnn|brnn|mean|transformer|cnn].") "[rnn|brnn|ggnn|mean|transformer|cnn].")
group.add('--decoder_type', '-decoder_type', type=str, default='rnn', group.add('--decoder_type', '-decoder_type', type=str, default='rnn',
choices=['rnn', 'transformer', 'cnn'], choices=['rnn', 'transformer', 'cnn'],
help="Type of decoder layer to use. Non-RNN layers " help="Type of decoder layer to use. Non-RNN layers "
@ -128,6 +273,27 @@ def model_opts(parser):
help="Type of context gate to use. " help="Type of context gate to use. "
"Do not select for no context gate.") "Do not select for no context gate.")
# The following options (bridge_extra_node to src_vocab) are used
# for training with --encoder_type ggnn (Gated Graph Neural Network).
group.add('--bridge_extra_node', '-bridge_extra_node',
type=bool, default=True,
help='Graph encoder bridges only extra node to decoder as input')
group.add('--bidir_edges', '-bidir_edges', type=bool, default=True,
help='Graph encoder autogenerates bidirectional edges')
group.add('--state_dim', '-state_dim', type=int, default=512,
help='Number of state dimensions in the graph encoder')
group.add('--n_edge_types', '-n_edge_types', type=int, default=2,
help='Number of edge types in the graph encoder')
group.add('--n_node', '-n_node', type=int, default=2,
help='Number of nodes in the graph encoder')
group.add('--n_steps', '-n_steps', type=int, default=2,
help='Number of steps to advance graph encoder')
# The ggnn uses src_vocab during training because the graph is built
# using edge information which requires parsing the input sequence.
group.add('--src_vocab', '-src_vocab', default="",
help="Path to an existing source vocabulary. Format: "
"one word per line.")
# Attention options # Attention options
group = parser.add_argument_group('Model- Attention') group = parser.add_argument_group('Model- Attention')
group.add('--global_attention', '-global_attention', group.add('--global_attention', '-global_attention',
@ -154,7 +320,22 @@ def model_opts(parser):
group.add('--aan_useffn', '-aan_useffn', action="store_true", group.add('--aan_useffn', '-aan_useffn', action="store_true",
help='Turn on the FFN layer in the AAN decoder') help='Turn on the FFN layer in the AAN decoder')
# Alignment options
group = parser.add_argument_group('Model - Alignment')
group.add('--lambda_align', '-lambda_align', type=float, default=0.0,
help="Lambda value for the alignment loss of Garg et al. (2019). "
"For more detailed information, see: "
"https://arxiv.org/abs/1909.02074")
group.add('--alignment_layer', '-alignment_layer', type=int, default=-3,
help='Layer number which has to be supervised.')
group.add('--alignment_heads', '-alignment_heads', type=int, default=0,
help='Number of cross-attention heads per layer to supervise with')
group.add('--full_context_alignment', '-full_context_alignment',
action="store_true",
help='Whether alignment is conditioned on full target context.')
# Generator and loss options. # Generator and loss options.
group = parser.add_argument_group('Generator')
group.add('--copy_attn', '-copy_attn', action="store_true", group.add('--copy_attn', '-copy_attn', action="store_true",
help='Train copy attention layer.') help='Train copy attention layer.')
group.add('--copy_attn_type', '-copy_attn_type', group.add('--copy_attn_type', '-copy_attn_type',
@ -181,7 +362,7 @@ def model_opts(parser):
group.add('--loss_scale', '-loss_scale', type=float, default=0, group.add('--loss_scale', '-loss_scale', type=float, default=0,
help="For FP16 training, the static loss scale to use. If not " help="For FP16 training, the static loss scale to use. If not "
"set, the loss scale is dynamically computed.") "set, the loss scale is dynamically computed.")
group.add('--apex_opt_level', '-apex_opt_level', type=str, default="O2", group.add('--apex_opt_level', '-apex_opt_level', type=str, default="O1",
choices=["O0", "O1", "O2", "O3"], choices=["O0", "O1", "O2", "O3"],
help="For FP16 training, the opt_level to use." help="For FP16 training, the opt_level to use."
"See https://nvidia.github.io/apex/amp.html#opt-levels.") "See https://nvidia.github.io/apex/amp.html#opt-levels.")
@ -199,12 +380,17 @@ def preprocess_opts(parser):
help="Path(s) to the training source data") help="Path(s) to the training source data")
group.add('--train_tgt', '-train_tgt', required=True, nargs='+', group.add('--train_tgt', '-train_tgt', required=True, nargs='+',
help="Path(s) to the training target data") help="Path(s) to the training target data")
group.add('--train_align', '-train_align', nargs='+', default=[None],
help="Path(s) to the training src-tgt alignment")
group.add('--train_ids', '-train_ids', nargs='+', default=[None], group.add('--train_ids', '-train_ids', nargs='+', default=[None],
help="ids to name training shards, used for corpus weighting") help="ids to name training shards, used for corpus weighting")
group.add('--valid_src', '-valid_src', group.add('--valid_src', '-valid_src',
help="Path to the validation source data") help="Path to the validation source data")
group.add('--valid_tgt', '-valid_tgt', group.add('--valid_tgt', '-valid_tgt',
help="Path to the validation target data") help="Path to the validation target data")
group.add('--valid_align', '-valid_align', default=None,
help="Path(s) to the validation src-tgt alignment")
group.add('--src_dir', '-src_dir', default="", group.add('--src_dir', '-src_dir', default="",
help="Source directory for image or audio files.") help="Source directory for image or audio files.")
@ -224,6 +410,9 @@ def preprocess_opts(parser):
"shard_size>0 means segment dataset into multiple shards, " "shard_size>0 means segment dataset into multiple shards, "
"each shard has shard_size samples") "each shard has shard_size samples")
group.add('--num_threads', '-num_threads', type=int, default=1,
help="Number of shards to build in parallel.")
group.add('--overwrite', '-overwrite', action="store_true", group.add('--overwrite', '-overwrite', action="store_true",
help="Overwrite existing shards if any.") help="Overwrite existing shards if any.")
@ -310,6 +499,15 @@ def preprocess_opts(parser):
help="Using grayscale image can training " help="Using grayscale image can training "
"model faster and smaller") "model faster and smaller")
# Options for experimental source noising (BART style)
group = parser.add_argument_group('Noise')
group.add('--subword_prefix', '-subword_prefix',
type=str, default="",
help="subword prefix to build wordstart mask")
group.add('--subword_prefix_is_joiner', '-subword_prefix_is_joiner',
action='store_true',
help="mask will need to be inverted if prefix is joiner")
def train_opts(parser): def train_opts(parser):
""" Training and saving options """ """ Training and saving options """
@ -324,6 +522,8 @@ def train_opts(parser):
group.add('--data_weights', '-data_weights', type=int, nargs='+', group.add('--data_weights', '-data_weights', type=int, nargs='+',
default=[1], help="""Weights of different corpora, default=[1], help="""Weights of different corpora,
should follow the same order as in -data_ids.""") should follow the same order as in -data_ids.""")
group.add('--data_to_noise', '-data_to_noise', nargs='+', default=[],
help="IDs of datasets on which to apply noise.")
group.add('--save_model', '-save_model', default='model', group.add('--save_model', '-save_model', default='model',
help="Model filename (the model will be saved as " help="Model filename (the model will be saved as "
@ -352,7 +552,7 @@ def train_opts(parser):
help="IP of master for torch.distributed training.") help="IP of master for torch.distributed training.")
group.add('--master_port', '-master_port', default=10000, type=int, group.add('--master_port', '-master_port', default=10000, type=int,
help="Port of master for torch.distributed training.") help="Port of master for torch.distributed training.")
group.add('--queue_size', '-queue_size', default=400, type=int, group.add('--queue_size', '-queue_size', default=40, type=int,
help="Size of queue for each process in producer/consumer") help="Size of queue for each process in producer/consumer")
group.add('--seed', '-seed', type=int, default=-1, group.add('--seed', '-seed', type=int, default=-1,
@ -472,10 +672,9 @@ def train_opts(parser):
'Typically a value of 0.999 is recommended, as this is ' 'Typically a value of 0.999 is recommended, as this is '
'the value suggested by the original paper describing ' 'the value suggested by the original paper describing '
'Adam, and is also the value adopted in other frameworks ' 'Adam, and is also the value adopted in other frameworks '
'such as Tensorflow and Kerras, i.e. see: ' 'such as Tensorflow and Keras, i.e. see: '
'https://www.tensorflow.org/api_docs/python/tf/train/Adam' 'https://www.tensorflow.org/api_docs/python/tf/train/Adam'
'Optimizer or ' 'Optimizer or https://keras.io/optimizers/ . '
'https://keras.io/optimizers/ . '
'Whereas recently the paper "Attention is All You Need" ' 'Whereas recently the paper "Attention is All You Need" '
'suggested a value of 0.98 for beta2, this parameter may ' 'suggested a value of 0.98 for beta2, this parameter may '
'not work well for normal models / default ' 'not work well for normal models / default '
@ -498,6 +697,12 @@ def train_opts(parser):
help="Step for moving average. " help="Step for moving average. "
"Default is every update, " "Default is every update, "
"if -average_decay is set.") "if -average_decay is set.")
group.add("--src_noise", "-src_noise", type=str, nargs='+',
default=[],
choices=onmt.modules.source_noise.MultiNoise.NOISES.keys())
group.add("--src_noise_prob", "-src_noise_prob", type=float, nargs='+',
default=[],
help="Probabilities of src_noise functions")
# learning rate # learning rate
group = parser.add_argument_group('Optimization- Rate') group = parser.add_argument_group('Optimization- Rate')
@ -536,10 +741,10 @@ def train_opts(parser):
help="Send logs to this crayon server.") help="Send logs to this crayon server.")
group.add('--exp', '-exp', type=str, default="", group.add('--exp', '-exp', type=str, default="",
help="Name of the experiment for logging.") help="Name of the experiment for logging.")
# Use TensorboardX for visualization during training # Use Tensorboard for visualization during training
group.add('--tensorboard', '-tensorboard', action="store_true", group.add('--tensorboard', '-tensorboard', action="store_true",
help="Use tensorboardX for visualization during training. " help="Use tensorboard for visualization during training. "
"Must have the library tensorboardX.") "Must have the library tensorboard >= 1.14.")
group.add("--tensorboard_log_dir", "-tensorboard_log_dir", group.add("--tensorboard_log_dir", "-tensorboard_log_dir",
type=str, default="runs/onmt", type=str, default="runs/onmt",
help="Log directory for Tensorboard. " help="Log directory for Tensorboard. "
@ -600,12 +805,8 @@ def translate_opts(parser):
group.add('--output', '-output', default='pred.txt', group.add('--output', '-output', default='pred.txt',
help="Path to output the predictions (each line will " help="Path to output the predictions (each line will "
"be the decoded sequence") "be the decoded sequence")
group.add('--report_bleu', '-report_bleu', action='store_true', group.add('--report_align', '-report_align', action='store_true',
help="Report bleu score after translation, " help="Report alignment for each translation.")
"call tools/multi-bleu.perl on command line")
group.add('--report_rouge', '-report_rouge', action='store_true',
help="Report rouge 1/2/3/L/SU4 score after translation "
"call tools/test_rouge.py on command line")
group.add('--report_time', '-report_time', action='store_true', group.add('--report_time', '-report_time', action='store_true',
help="Report some translation time metrics") help="Report some translation time metrics")
@ -690,6 +891,8 @@ def translate_opts(parser):
default="0") default="0")
group.add('--attn_debug', '-attn_debug', action="store_true", group.add('--attn_debug', '-attn_debug', action="store_true",
help='Print best attn for each word') help='Print best attn for each word')
group.add('--align_debug', '-align_debug', action="store_true",
help='Print best align for each word')
group.add('--dump_beam', '-dump_beam', type=str, default="", group.add('--dump_beam', '-dump_beam', type=str, default="",
help='File to dump beam information to.') help='File to dump beam information to.')
group.add('--n_best', '-n_best', type=int, default=1, group.add('--n_best', '-n_best', type=int, default=1,
@ -699,6 +902,10 @@ def translate_opts(parser):
group = parser.add_argument_group('Efficiency') group = parser.add_argument_group('Efficiency')
group.add('--batch_size', '-batch_size', type=int, default=30, group.add('--batch_size', '-batch_size', type=int, default=30,
help='Batch size') help='Batch size')
group.add('--batch_type', '-batch_type', default='sents',
choices=["sents", "tokens"],
help="Batch grouping for batch_size. Standard "
"is sents. Tokens will do dynamic batching")
group.add('--gpu', '-gpu', type=int, default=-1, group.add('--gpu', '-gpu', type=int, default=-1,
help="Device to run on") help="Device to run on")

opennmt/onmt/tests/pull_request_chk.sh Normal file → Executable file
opennmt/onmt/tests/rebuild_test_models.sh Normal file → Executable file

@ -1,383 +0,0 @@
import unittest
from onmt.translate.beam import Beam, GNMTGlobalScorer
import torch
class GlobalScorerStub(object):
def update_global_state(self, beam):
pass
def score(self, beam, scores):
return scores
class TestBeam(unittest.TestCase):
BLOCKED_SCORE = -10e20
def test_advance_with_all_repeats_gets_blocked(self):
# all beams repeat (beam >= 1 repeat dummy scores)
beam_sz = 5
n_words = 100
repeat_idx = 47
ngram_repeat = 3
beam = Beam(beam_sz, 0, 1, 2, n_best=2,
exclusion_tokens=set(),
global_scorer=GlobalScorerStub(),
block_ngram_repeat=ngram_repeat)
for i in range(ngram_repeat + 4):
# predict repeat_idx over and over again
word_probs = torch.full((beam_sz, n_words), -float('inf'))
word_probs[0, repeat_idx] = 0
attns = torch.randn(beam_sz)
beam.advance(word_probs, attns)
if i <= ngram_repeat:
self.assertTrue(
beam.scores.equal(
torch.tensor(
[0] + [-float('inf')] * (beam_sz - 1))))
else:
self.assertTrue(
beam.scores.equal(torch.tensor(
[self.BLOCKED_SCORE] * beam_sz)))
def test_advance_with_some_repeats_gets_blocked(self):
# beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
beam_sz = 5
n_words = 100
repeat_idx = 47
ngram_repeat = 3
beam = Beam(beam_sz, 0, 1, 2, n_best=2,
exclusion_tokens=set(),
global_scorer=GlobalScorerStub(),
block_ngram_repeat=ngram_repeat)
for i in range(ngram_repeat + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full((beam_sz, n_words), -float('inf'))
if i == 0:
# on initial round, only predicted scores for beam 0
# matter. Make two predictions. Top one will be repeated
# in beam zero, second one will live on in beam 1.
word_probs[0, repeat_idx] = -0.1
word_probs[0, repeat_idx + i + 1] = -2.3
else:
# predict the same thing in beam 0
word_probs[0, repeat_idx] = 0
# continue pushing around what beam 1 predicts
word_probs[1, repeat_idx + i + 1] = 0
attns = torch.randn(beam_sz)
beam.advance(word_probs, attns)
if i <= ngram_repeat:
self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
else:
# now beam 0 dies (along with the others), beam 1 -> beam 0
self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
self.assertTrue(
beam.scores[1:].equal(torch.tensor(
[self.BLOCKED_SCORE] * (beam_sz - 1))))
def test_repeating_excluded_index_does_not_die(self):
# beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
beam_sz = 5
n_words = 100
repeat_idx = 47 # will be repeated and should be blocked
repeat_idx_ignored = 7 # will be repeated and should not be blocked
ngram_repeat = 3
beam = Beam(beam_sz, 0, 1, 2, n_best=2,
exclusion_tokens=set([repeat_idx_ignored]),
global_scorer=GlobalScorerStub(),
block_ngram_repeat=ngram_repeat)
for i in range(ngram_repeat + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full((beam_sz, n_words), -float('inf'))
if i == 0:
word_probs[0, repeat_idx] = -0.1
word_probs[0, repeat_idx + i + 1] = -2.3
word_probs[0, repeat_idx_ignored] = -5.0
else:
# predict the same thing in beam 0
word_probs[0, repeat_idx] = 0
# continue pushing around what beam 1 predicts
word_probs[1, repeat_idx + i + 1] = 0
# predict the allowed-repeat again in beam 2
word_probs[2, repeat_idx_ignored] = 0
attns = torch.randn(beam_sz)
beam.advance(word_probs, attns)
if i <= ngram_repeat:
self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
self.assertFalse(beam.scores[2].eq(self.BLOCKED_SCORE))
else:
# now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
# and the rest die
self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
# since all preds after i=0 are 0, we can check
# that the beam is the correct idx by checking that
# the curr score is the initial score
self.assertTrue(beam.scores[0].eq(-2.3))
self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
self.assertTrue(beam.scores[1].eq(-5.0))
self.assertTrue(
beam.scores[2:].equal(torch.tensor(
[self.BLOCKED_SCORE] * (beam_sz - 2))))
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# beam 0 will always predict EOS. The other beams will predict
# non-eos scores.
# this is also a test that when block_ngram_repeat=0,
# repeating is acceptable
beam_sz = 5
n_words = 100
_non_eos_idxs = [47, 51, 13, 88, 99]
valid_score_dist = torch.log_softmax(torch.tensor(
[6., 5., 4., 3., 2., 1.]), dim=0)
min_length = 5
eos_idx = 2
beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
exclusion_tokens=set(),
min_length=min_length,
global_scorer=GlobalScorerStub(),
block_ngram_repeat=0)
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full((beam_sz, n_words), -float('inf'))
if i == 0:
# "best" prediction is eos - that should be blocked
word_probs[0, eos_idx] = valid_score_dist[0]
# include at least beam_sz predictions OTHER than EOS
# that are greater than -1e20
for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
word_probs[0, j] = score
else:
# predict eos in beam 0
word_probs[0, eos_idx] = valid_score_dist[0]
# provide beam_sz other good predictions
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
word_probs[beam_idx, j] = score
attns = torch.randn(beam_sz)
beam.advance(word_probs, attns)
if i < min_length:
expected_score_dist = (i+1) * valid_score_dist[1:]
self.assertTrue(beam.scores.allclose(expected_score_dist))
elif i == min_length:
# now the top beam has ended and no others have
# first beam finished had length beam.min_length
self.assertEqual(beam.finished[0][1], beam.min_length + 1)
# first beam finished was 0
self.assertEqual(beam.finished[0][2], 0)
else: # i > min_length
# not of interest, but want to make sure it keeps running
# since only beam 0 terminates and n_best = 2
pass
def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
# this is also a test that when block_ngram_repeat=0,
# repeating is acceptable
beam_sz = 5
n_words = 100
_non_eos_idxs = [47, 51, 13, 88, 99]
valid_score_dist = torch.log_softmax(torch.tensor(
[6., 5., 4., 3., 2., 1.]), dim=0)
min_length = 5
eos_idx = 2
beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
exclusion_tokens=set(),
min_length=min_length,
global_scorer=GlobalScorerStub(),
block_ngram_repeat=0)
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full((beam_sz, n_words), -float('inf'))
if i == 0:
# "best" prediction is eos - that should be blocked
word_probs[0, eos_idx] = valid_score_dist[0]
# include at least beam_sz predictions OTHER than EOS
# that are greater than -1e20
for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
word_probs[0, j] = score
elif i <= min_length:
# predict eos in beam 1
word_probs[1, eos_idx] = valid_score_dist[0]
# provide beam_sz other good predictions in other beams
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
word_probs[beam_idx, j] = score
else:
word_probs[0, eos_idx] = valid_score_dist[0]
word_probs[1, eos_idx] = valid_score_dist[0]
# provide beam_sz other good predictions in other beams
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
word_probs[beam_idx, j] = score
attns = torch.randn(beam_sz)
beam.advance(word_probs, attns)
if i < min_length:
self.assertFalse(beam.done)
elif i == min_length:
# beam 1 dies on min_length
self.assertEqual(beam.finished[0][1], beam.min_length + 1)
self.assertEqual(beam.finished[0][2], 1)
self.assertFalse(beam.done)
else: # i > min_length
# beam 0 dies on the step after beam 1 dies
self.assertEqual(beam.finished[1][1], beam.min_length + 2)
self.assertEqual(beam.finished[1][2], 0)
self.assertTrue(beam.done)
class TestBeamAgainstReferenceCase(unittest.TestCase):
BEAM_SZ = 5
EOS_IDX = 2 # don't change this - all the scores would need updated
N_WORDS = 8 # also don't change for same reason
N_BEST = 3
DEAD_SCORE = -1e20
INP_SEQ_LEN = 53
def init_step(self, beam):
# init_preds: [4, 3, 5, 6, 7] - no EOS's
init_scores = torch.log_softmax(torch.tensor(
[[0, 0, 0, 4, 5, 3, 2, 1]], dtype=torch.float), dim=1)
expected_beam_scores, expected_preds_0 = init_scores.topk(self.BEAM_SZ)
beam.advance(init_scores, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))
self.assertTrue(beam.scores.allclose(expected_beam_scores))
self.assertTrue(beam.next_ys[-1].equal(expected_preds_0[0]))
self.assertFalse(beam.eos_top)
self.assertFalse(beam.done)
return expected_beam_scores
def first_step(self, beam, expected_beam_scores, expected_len_pen):
# no EOS's yet
assert len(beam.finished) == 0
scores_1 = torch.log_softmax(torch.tensor(
[[0, 0, 0, .3, 0, .51, .2, 0],
[0, 0, 1.5, 0, 0, 0, 0, 0],
[0, 0, 0, 0, .49, .48, 0, 0],
[0, 0, 0, .2, .2, .2, .2, .2],
[0, 0, 0, .2, .2, .2, .2, .2]]
), dim=1)
beam.advance(scores_1, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))
new_scores = scores_1 + expected_beam_scores.t()
expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
self.BEAM_SZ, 0, True, True)
expected_bptr_1 = unreduced_preds / self.N_WORDS
# [5, 3, 2, 6, 0], so beam 2 predicts EOS!
expected_preds_1 = unreduced_preds - expected_bptr_1 * self.N_WORDS
self.assertTrue(beam.scores.allclose(expected_beam_scores))
self.assertTrue(beam.next_ys[-1].equal(expected_preds_1))
self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_1))
self.assertEqual(len(beam.finished), 1)
self.assertEqual(beam.finished[0][2], 2) # beam 2 finished
self.assertEqual(beam.finished[0][1], 2) # finished on second step
self.assertEqual(beam.finished[0][0], # finished with correct score
expected_beam_scores[2] / expected_len_pen)
self.assertFalse(beam.eos_top)
self.assertFalse(beam.done)
return expected_beam_scores
def second_step(self, beam, expected_beam_scores, expected_len_pen):
# assumes beam 2 finished on last step
scores_2 = torch.log_softmax(torch.tensor(
[[0, 0, 0, .3, 0, .51, .2, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 5000, .48, 0, 0], # beam 2 shouldn't continue
[0, 0, 50, .2, .2, .2, .2, .2], # beam 3 -> beam 0 should die
[0, 0, 0, .2, .2, .2, .2, .2]]
), dim=1)
beam.advance(scores_2, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))
new_scores = scores_2 + expected_beam_scores.unsqueeze(1)
new_scores[2] = self.DEAD_SCORE # ended beam 2 shouldn't continue
expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
self.BEAM_SZ, 0, True, True)
expected_bptr_2 = unreduced_preds / self.N_WORDS
# [2, 5, 3, 6, 0], so beam 0 predicts EOS!
expected_preds_2 = unreduced_preds - expected_bptr_2 * self.N_WORDS
# [-2.4879, -3.8910, -4.1010, -4.2010, -4.4010]
self.assertTrue(beam.scores.allclose(expected_beam_scores))
self.assertTrue(beam.next_ys[-1].equal(expected_preds_2))
self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_2))
self.assertEqual(len(beam.finished), 2)
# new beam 0 finished
self.assertEqual(beam.finished[1][2], 0)
# new beam 0 is old beam 3
self.assertEqual(expected_bptr_2[0], 3)
self.assertEqual(beam.finished[1][1], 3) # finished on third step
self.assertEqual(beam.finished[1][0], # finished with correct score
expected_beam_scores[0] / expected_len_pen)
self.assertTrue(beam.eos_top)
self.assertFalse(beam.done)
return expected_beam_scores
def third_step(self, beam, expected_beam_scores, expected_len_pen):
# assumes beam 0 finished on last step
scores_3 = torch.log_softmax(torch.tensor(
[[0, 0, 5000, 0, 5000, .51, .2, 0], # beam 0 shouldn't continue
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5000, 0, 0],
[0, 0, 0, .2, .2, .2, .2, .2],
[0, 0, 50, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die
), dim=1)
beam.advance(scores_3, torch.randn(self.BEAM_SZ, self.INP_SEQ_LEN))
new_scores = scores_3 + expected_beam_scores.unsqueeze(1)
new_scores[0] = self.DEAD_SCORE # ended beam 0 shouldn't continue
expected_beam_scores, unreduced_preds = new_scores.view(-1).topk(
self.BEAM_SZ, 0, True, True)
expected_bptr_3 = unreduced_preds / self.N_WORDS
# [5, 2, 6, 1, 0], so beam 1 predicts EOS!
expected_preds_3 = unreduced_preds - expected_bptr_3 * self.N_WORDS
self.assertTrue(beam.scores.allclose(expected_beam_scores))
self.assertTrue(beam.next_ys[-1].equal(expected_preds_3))
self.assertTrue(beam.prev_ks[-1].equal(expected_bptr_3))
self.assertEqual(len(beam.finished), 3)
# new beam 1 finished
self.assertEqual(beam.finished[2][2], 1)
# new beam 1 is old beam 4
self.assertEqual(expected_bptr_3[1], 4)
self.assertEqual(beam.finished[2][1], 4) # finished on fourth step
self.assertEqual(beam.finished[2][0], # finished with correct score
expected_beam_scores[1] / expected_len_pen)
self.assertTrue(beam.eos_top)
self.assertTrue(beam.done)
return expected_beam_scores
def test_beam_advance_against_known_reference(self):
beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
exclusion_tokens=set(),
min_length=0,
global_scorer=GlobalScorerStub(),
block_ngram_repeat=0)
expected_beam_scores = self.init_step(beam)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
self.third_step(beam, expected_beam_scores, 1)
class TestBeamWithLengthPenalty(TestBeamAgainstReferenceCase):
# this could be considered an integration test because it tests
# interactions between the GNMT scorer and the beam
def test_beam_advance_against_known_reference(self):
scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
exclusion_tokens=set(),
min_length=0,
global_scorer=scorer,
block_ngram_repeat=0)
expected_beam_scores = self.init_step(beam)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
self.third_step(beam, expected_beam_scores, 5)
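
A note for readers tracing the assertions in the reference tests above: the legacy `Beam` records each hypothesis that emits EOS as a `(score, timestep, beam_index)` tuple in `beam.finished`, which is what the `finished[i][0]`/`[1]`/`[2]` indexing checks. A standalone sketch of that bookkeeping (illustrative only, not OpenNMT code):

```
# Each finished hypothesis is stored as (score, timestep, beam_index);
# the reference tests above index these positions directly.
finished = []

def record_finished(scores, timestep, beams_that_hit_eos):
    for k in beams_that_hit_eos:
        finished.append((scores[k], timestep, k))

record_finished(scores=[-2.5, -3.9], timestep=2, beams_that_hit_eos=[1])
assert finished[0][2] == 1   # which beam produced EOS
assert finished[0][1] == 2   # at which decoding step
```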


@ -1,6 +1,5 @@
import unittest import unittest
from onmt.translate.beam import GNMTGlobalScorer from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer
from onmt.translate.beam_search import BeamSearch
from copy import deepcopy from copy import deepcopy
@ -34,29 +33,49 @@ class TestBeamSearch(unittest.TestCase):
n_words = 100 n_words = 100
repeat_idx = 47 repeat_idx = 47
ngram_repeat = 3 ngram_repeat = 3
device_init = torch.zeros(1, 1)
for batch_sz in [1, 3]: for batch_sz in [1, 3]:
beam = BeamSearch( beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2, beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), 0, 30, GlobalScorerStub(), 0, 30,
False, ngram_repeat, set(), False, ngram_repeat, set(),
torch.randint(0, 30, (batch_sz,)), False, 0.) False, 0.)
beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4): for i in range(ngram_repeat + 4):
# predict repeat_idx over and over again # predict repeat_idx over and over again
word_probs = torch.full( word_probs = torch.full(
(batch_sz * beam_sz, n_words), -float('inf')) (batch_sz * beam_sz, n_words), -float('inf'))
word_probs[0::beam_sz, repeat_idx] = 0 word_probs[0::beam_sz, repeat_idx] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53) attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns) beam.advance(word_probs, attns)
if i <= ngram_repeat:
if i < ngram_repeat:
# before repeat, scores are either 0 or -inf
expected_scores = torch.tensor( expected_scores = torch.tensor(
[0] + [-float('inf')] * (beam_sz - 1))\ [0] + [-float('inf')] * (beam_sz - 1))\
.repeat(batch_sz, 1) .repeat(batch_sz, 1)
self.assertTrue(beam.topk_log_probs.equal(expected_scores)) self.assertTrue(beam.topk_log_probs.equal(expected_scores))
elif i % ngram_repeat == 0:
# on repeat, `repeat_idx` score is BLOCKED_SCORE
# (but it's still the best score, thus we have
# [BLOCKED_SCORE, -inf, -inf, -inf, -inf]
expected_scores = torch.tensor(
[0] + [-float('inf')] * (beam_sz - 1))\
.repeat(batch_sz, 1)
expected_scores[:, 0] = self.BLOCKED_SCORE
self.assertTrue(beam.topk_log_probs.equal(expected_scores))
else: else:
self.assertTrue( # repetitions keep maximizing the score
beam.topk_log_probs.equal( # index 0 has been blocked, so repeating=>+0.0 score
torch.tensor(self.BLOCKED_SCORE) # other indexes are -inf so repeating=>BLOCKED_SCORE
.repeat(batch_sz, beam_sz))) # which is higher
expected_scores = torch.tensor(
[0] + [-float('inf')] * (beam_sz - 1))\
.repeat(batch_sz, 1)
expected_scores[:, :] = self.BLOCKED_SCORE
expected_scores = torch.tensor(
self.BLOCKED_SCORE).repeat(batch_sz, beam_sz)
def test_advance_with_some_repeats_gets_blocked(self): def test_advance_with_some_repeats_gets_blocked(self):
# beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores) # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
@ -64,12 +83,16 @@ class TestBeamSearch(unittest.TestCase):
n_words = 100 n_words = 100
repeat_idx = 47 repeat_idx = 47
ngram_repeat = 3 ngram_repeat = 3
no_repeat_score = -2.3
repeat_score = -0.1
device_init = torch.zeros(1, 1)
for batch_sz in [1, 3]: for batch_sz in [1, 3]:
beam = BeamSearch( beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2, beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), 0, 30, GlobalScorerStub(), 0, 30,
False, ngram_repeat, set(), False, ngram_repeat, set(),
torch.randint(0, 30, (batch_sz,)), False, 0.) False, 0.)
beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4): for i in range(ngram_repeat + 4):
# non-interesting beams are going to get dummy values # non-interesting beams are going to get dummy values
word_probs = torch.full( word_probs = torch.full(
@ -78,8 +101,9 @@ class TestBeamSearch(unittest.TestCase):
# on initial round, only predicted scores for beam 0 # on initial round, only predicted scores for beam 0
# matter. Make two predictions. Top one will be repeated # matter. Make two predictions. Top one will be repeated
# in beam zero, second one will live on in beam 1. # in beam zero, second one will live on in beam 1.
word_probs[0::beam_sz, repeat_idx] = -0.1 word_probs[0::beam_sz, repeat_idx] = repeat_score
word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3 word_probs[0::beam_sz, repeat_idx +
i + 1] = no_repeat_score
else: else:
# predict the same thing in beam 0 # predict the same thing in beam 0
word_probs[0::beam_sz, repeat_idx] = 0 word_probs[0::beam_sz, repeat_idx] = 0
@ -87,22 +111,35 @@ class TestBeamSearch(unittest.TestCase):
word_probs[1::beam_sz, repeat_idx + i + 1] = 0 word_probs[1::beam_sz, repeat_idx + i + 1] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53) attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns) beam.advance(word_probs, attns)
if i <= ngram_repeat: if i < ngram_repeat:
self.assertFalse( self.assertFalse(
beam.topk_log_probs[0::beam_sz].eq( beam.topk_log_probs[0::beam_sz].eq(
self.BLOCKED_SCORE).any()) self.BLOCKED_SCORE).any())
self.assertFalse( self.assertFalse(
beam.topk_log_probs[1::beam_sz].eq( beam.topk_log_probs[1::beam_sz].eq(
self.BLOCKED_SCORE).any()) self.BLOCKED_SCORE).any())
elif i == ngram_repeat:
# now beam 0 dies (along with the others), beam 1 -> beam 0
self.assertFalse(
beam.topk_log_probs[:, 0].eq(
self.BLOCKED_SCORE).any())
expected = torch.full([batch_sz, beam_sz], float("-inf"))
expected[:, 0] = no_repeat_score
expected[:, 1] = self.BLOCKED_SCORE
self.assertTrue(
beam.topk_log_probs[:, :].equal(expected))
else: else:
# now beam 0 dies (along with the others), beam 1 -> beam 0 # now beam 0 dies (along with the others), beam 1 -> beam 0
self.assertFalse( self.assertFalse(
beam.topk_log_probs[:, 0].eq( beam.topk_log_probs[:, 0].eq(
self.BLOCKED_SCORE).any()) self.BLOCKED_SCORE).any())
expected = torch.full([batch_sz, beam_sz], float("-inf"))
expected[:, 0] = no_repeat_score
expected[:, 1:] = self.BLOCKED_SCORE
self.assertTrue( self.assertTrue(
beam.topk_log_probs[:, 1:].equal( beam.topk_log_probs.equal(expected))
torch.tensor(self.BLOCKED_SCORE)
.repeat(batch_sz, beam_sz-1)))
def test_repeating_excluded_index_does_not_die(self): def test_repeating_excluded_index_does_not_die(self):
# beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx) # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
@ -111,12 +148,14 @@ class TestBeamSearch(unittest.TestCase):
repeat_idx = 47 # will be repeated and should be blocked repeat_idx = 47 # will be repeated and should be blocked
repeat_idx_ignored = 7 # will be repeated and should not be blocked repeat_idx_ignored = 7 # will be repeated and should not be blocked
ngram_repeat = 3 ngram_repeat = 3
device_init = torch.zeros(1, 1)
for batch_sz in [1, 3]: for batch_sz in [1, 3]:
beam = BeamSearch( beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2, beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), 0, 30, GlobalScorerStub(), 0, 30,
False, ngram_repeat, {repeat_idx_ignored}, False, ngram_repeat, {repeat_idx_ignored},
torch.randint(0, 30, (batch_sz,)), False, 0.) False, 0.)
beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4): for i in range(ngram_repeat + 4):
# non-interesting beams are going to get dummy values # non-interesting beams are going to get dummy values
word_probs = torch.full( word_probs = torch.full(
@ -134,7 +173,7 @@ class TestBeamSearch(unittest.TestCase):
word_probs[2::beam_sz, repeat_idx_ignored] = 0 word_probs[2::beam_sz, repeat_idx_ignored] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53) attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns) beam.advance(word_probs, attns)
if i <= ngram_repeat: if i < ngram_repeat:
self.assertFalse(beam.topk_log_probs[:, 0].eq( self.assertFalse(beam.topk_log_probs[:, 0].eq(
self.BLOCKED_SCORE).any()) self.BLOCKED_SCORE).any())
self.assertFalse(beam.topk_log_probs[:, 1].eq( self.assertFalse(beam.topk_log_probs[:, 1].eq(
@ -153,10 +192,9 @@ class TestBeamSearch(unittest.TestCase):
self.assertFalse(beam.topk_log_probs[:, 1].eq( self.assertFalse(beam.topk_log_probs[:, 1].eq(
self.BLOCKED_SCORE).all()) self.BLOCKED_SCORE).all())
self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all()) self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
self.assertTrue(
beam.topk_log_probs[:, 2:].equal( self.assertTrue(beam.topk_log_probs[:, 2].eq(
torch.tensor(self.BLOCKED_SCORE) self.BLOCKED_SCORE).all())
.repeat(batch_sz, beam_sz - 2)))
def test_doesnt_predict_eos_if_shorter_than_min_len(self): def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# beam 0 will always predict EOS. The other beams will predict # beam 0 will always predict EOS. The other beams will predict
@ -171,9 +209,11 @@ class TestBeamSearch(unittest.TestCase):
eos_idx = 2 eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,)) lengths = torch.randint(0, 30, (batch_sz,))
beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2, beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), GlobalScorerStub(),
min_length, 30, False, 0, set(), min_length, 30, False, 0, set(),
lengths, False, 0.) False, 0.)
device_init = torch.zeros(1, 1)
beam.initialize(device_init, lengths)
all_attns = [] all_attns = []
for i in range(min_length + 4): for i in range(min_length + 4):
# non-interesting beams are going to get dummy values # non-interesting beams are going to get dummy values
@ -226,9 +266,11 @@ class TestBeamSearch(unittest.TestCase):
eos_idx = 2 eos_idx = 2
beam = BeamSearch( beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2, beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), GlobalScorerStub(),
min_length, 30, False, 0, set(), min_length, 30, False, 0, set(),
torch.randint(0, 30, (batch_sz,)), False, 0.) False, 0.)
device_init = torch.zeros(1, 1)
beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
for i in range(min_length + 4): for i in range(min_length + 4):
# non-interesting beams are going to get dummy values # non-interesting beams are going to get dummy values
word_probs = torch.full( word_probs = torch.full(
@ -284,9 +326,12 @@ class TestBeamSearch(unittest.TestCase):
inp_lens = torch.randint(1, 30, (batch_sz,)) inp_lens = torch.randint(1, 30, (batch_sz,))
beam = BeamSearch( beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2, beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), GlobalScorerStub(),
min_length, 30, True, 0, set(), min_length, 30, True, 0, set(),
inp_lens, False, 0.) False, 0.)
device_init = torch.zeros(1, 1)
_, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
# inp_lens is tiled in initialize, reassign to make attn match
for i in range(min_length + 2): for i in range(min_length + 2):
# non-interesting beams are going to get dummy values # non-interesting beams are going to get dummy values
word_probs = torch.full( word_probs = torch.full(
@ -495,10 +540,11 @@ class TestBeamSearchAgainstReferenceCase(unittest.TestCase):
def test_beam_advance_against_known_reference(self): def test_beam_advance_against_known_reference(self):
beam = BeamSearch( beam = BeamSearch(
self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST, self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
torch.device("cpu"), GlobalScorerStub(), GlobalScorerStub(),
0, 30, False, 0, set(), 0, 30, False, 0, set(),
torch.randint(0, 30, (self.BATCH_SZ,)), False, 0.) False, 0.)
device_init = torch.zeros(1, 1)
beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ,)))
expected_beam_scores = self.init_step(beam, 1) expected_beam_scores = self.init_step(beam, 1)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 1) expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 1) expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
@ -513,9 +559,11 @@ class TestBeamWithLengthPenalty(TestBeamSearchAgainstReferenceCase):
scorer = GNMTGlobalScorer(0.7, 0., "avg", "none") scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
beam = BeamSearch( beam = BeamSearch(
self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST, self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
torch.device("cpu"), scorer, scorer,
0, 30, False, 0, set(), 0, 30, False, 0, set(),
torch.randint(0, 30, (self.BATCH_SZ,)), False, 0.) False, 0.)
device_init = torch.zeros(1, 1)
beam.initialize(device_init, torch.randint(0, 30, (self.BATCH_SZ,)))
expected_beam_scores = self.init_step(beam, 1.) expected_beam_scores = self.init_step(beam, 1.)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 3) expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 4) expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
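
The recurring change in the hunks above is that `BeamSearch` no longer takes a device or `memory_lengths` in its constructor; those now arrive through `initialize(memory_bank, src_lengths, ...)`, which must be called before the first `advance`. A condensed sketch of the updated call pattern, assuming the constructor and `initialize` signatures shown in this diff (the scorer here is the real `GNMTGlobalScorer` rather than the test stub):

```
import torch
from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer

beam_sz, batch_sz, n_words = 5, 3, 100
scorer = GNMTGlobalScorer(0., 0., "none", "none")

beam = BeamSearch(
    beam_sz, batch_sz, 0, 1, 2, 2,   # pad=0, bos=1, eos=2, n_best=2
    scorer, 0, 30,                   # global_scorer, min_length, max_length
    False, 0, set(),                 # return_attention, block_ngram_repeat, exclusion_tokens
    False, 0.)                       # stepwise_penalty, ratio
# device and source lengths are now supplied here, not in the constructor
beam.initialize(torch.zeros(1, 1), torch.randint(0, 30, (batch_sz,)))

for step in range(4):
    word_probs = torch.full((batch_sz * beam_sz, n_words), -float("inf"))
    word_probs[0::beam_sz, 47] = 0                  # keep one live hypothesis per batch
    attns = torch.randn(1, batch_sz * beam_sz, 53)
    beam.advance(word_probs, attns)
    if beam.is_finished.any():
        beam.update_finished()
    if beam.done:
        break
```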


@ -1,114 +1,16 @@
import unittest import unittest
from onmt.translate.random_sampling import RandomSampling from onmt.translate.greedy_search import GreedySearch
import torch import torch
class TestRandomSampling(unittest.TestCase): class TestGreedySearch(unittest.TestCase):
BATCH_SZ = 3 BATCH_SZ = 3
INP_SEQ_LEN = 53 INP_SEQ_LEN = 53
DEAD_SCORE = -1e20 DEAD_SCORE = -1e20
BLOCKED_SCORE = -10e20 BLOCKED_SCORE = -10e20
def test_advance_with_repeats_gets_blocked(self):
n_words = 100
repeat_idx = 47
ngram_repeat = 3
for batch_sz in [1, 3]:
samp = RandomSampling(
0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat, set(),
False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4):
# predict repeat_idx over and over again
word_probs = torch.full(
(batch_sz, n_words), -float('inf'))
word_probs[:, repeat_idx] = 0
attns = torch.randn(1, batch_sz, 53)
samp.advance(word_probs, attns)
if i <= ngram_repeat:
expected_scores = torch.zeros((batch_sz, 1))
self.assertTrue(samp.topk_scores.equal(expected_scores))
else:
self.assertTrue(
samp.topk_scores.equal(
torch.tensor(self.BLOCKED_SCORE)
.repeat(batch_sz, 1)))
def test_advance_with_some_repeats_gets_blocked(self):
# batch 0 and 7 will repeat, the rest will advance
n_words = 100
repeat_idx = 47
other_repeat_idx = 12
ngram_repeat = 3
for batch_sz in [1, 3, 13]:
samp = RandomSampling(
0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat, set(),
False, 30, 1., 5, torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4):
word_probs = torch.full(
(batch_sz, n_words), -float('inf'))
# predict the same thing in batch 0 and 7 every i
word_probs[0, repeat_idx] = 0
if batch_sz > 7:
word_probs[7, other_repeat_idx] = 0
# push around what the other batches predict
word_probs[1:7, repeat_idx + i] = 0
if batch_sz > 7:
word_probs[8:, repeat_idx + i] = 0
attns = torch.randn(1, batch_sz, 53)
samp.advance(word_probs, attns)
if i <= ngram_repeat:
self.assertFalse(
samp.topk_scores.eq(
self.BLOCKED_SCORE).any())
else:
# now batch 0 and 7 die
self.assertTrue(samp.topk_scores[0].eq(self.BLOCKED_SCORE))
if batch_sz > 7:
self.assertTrue(samp.topk_scores[7].eq(
self.BLOCKED_SCORE))
self.assertFalse(
samp.topk_scores[1:7].eq(
self.BLOCKED_SCORE).any())
if batch_sz > 7:
self.assertFalse(
samp.topk_scores[8:].eq(
self.BLOCKED_SCORE).any())
def test_repeating_excluded_index_does_not_die(self):
# batch 0 will repeat excluded idx, batch 1 will repeat
n_words = 100
repeat_idx = 47 # will be repeated and should be blocked
repeat_idx_ignored = 7 # will be repeated and should not be blocked
ngram_repeat = 3
for batch_sz in [1, 3, 17]:
samp = RandomSampling(
0, 1, 2, batch_sz, torch.device("cpu"), 0, ngram_repeat,
{repeat_idx_ignored}, False, 30, 1., 5,
torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4):
word_probs = torch.full(
(batch_sz, n_words), -float('inf'))
word_probs[0, repeat_idx_ignored] = 0
if batch_sz > 1:
word_probs[1, repeat_idx] = 0
word_probs[2:, repeat_idx + i] = 0
attns = torch.randn(1, batch_sz, 53)
samp.advance(word_probs, attns)
if i <= ngram_repeat:
self.assertFalse(samp.topk_scores.eq(
self.BLOCKED_SCORE).any())
else:
# now batch 1 dies
self.assertFalse(samp.topk_scores[0].eq(
self.BLOCKED_SCORE).any())
if batch_sz > 1:
self.assertTrue(samp.topk_scores[1].eq(
self.BLOCKED_SCORE).all())
self.assertFalse(samp.topk_scores[2:].eq(
self.BLOCKED_SCORE).any())
def test_doesnt_predict_eos_if_shorter_than_min_len(self): def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# batch 0 will always predict EOS. The other batches will predict # batch 0 will always predict EOS. The other batches will predict
# non-eos scores. # non-eos scores.
@ -120,9 +22,10 @@ class TestRandomSampling(unittest.TestCase):
min_length = 5 min_length = 5
eos_idx = 2 eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,)) lengths = torch.randint(0, 30, (batch_sz,))
samp = RandomSampling( samp = GreedySearch(
0, 1, 2, batch_sz, torch.device("cpu"), min_length, 0, 1, 2, batch_sz, min_length,
False, set(), False, 30, 1., 1, lengths) False, set(), False, 30, 1., 1)
samp.initialize(torch.zeros(1), lengths)
all_attns = [] all_attns = []
for i in range(min_length + 4): for i in range(min_length + 4):
word_probs = torch.full( word_probs = torch.full(
@ -160,10 +63,10 @@ class TestRandomSampling(unittest.TestCase):
[6., 1.]), dim=0) [6., 1.]), dim=0)
eos_idx = 2 eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,)) lengths = torch.randint(0, 30, (batch_sz,))
samp = RandomSampling( samp = GreedySearch(
0, 1, 2, batch_sz, torch.device("cpu"), 0, 0, 1, 2, batch_sz, 0,
False, set(), False, 30, temp, 1, lengths) False, set(), False, 30, temp, 1)
samp.initialize(torch.zeros(1), lengths)
# initial step # initial step
i = 0 i = 0
word_probs = torch.full( word_probs = torch.full(
@ -232,10 +135,10 @@ class TestRandomSampling(unittest.TestCase):
[6., 1.]), dim=0) [6., 1.]), dim=0)
eos_idx = 2 eos_idx = 2
lengths = torch.randint(0, 30, (batch_sz,)) lengths = torch.randint(0, 30, (batch_sz,))
samp = RandomSampling( samp = GreedySearch(
0, 1, 2, batch_sz, torch.device("cpu"), 0, 0, 1, 2, batch_sz, 0,
False, set(), False, 30, temp, 2, lengths) False, set(), False, 30, temp, 2)
samp.initialize(torch.zeros(1), lengths)
# initial step # initial step
i = 0 i = 0
for _ in range(100): for _ in range(100):
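
`RandomSampling` is renamed to `GreedySearch` in this file, with the same constructor arguments minus the device and source lengths, which (as with `BeamSearch`) move into `initialize`. A minimal sketch mirroring the updated tests:

```
import torch
from onmt.translate.greedy_search import GreedySearch

batch_sz, n_words = 3, 100
lengths = torch.randint(0, 30, (batch_sz,))

# same positional arguments as the old RandomSampling call, minus
# torch.device("cpu") and the lengths tensor
samp = GreedySearch(0, 1, 2, batch_sz, 0, False, set(), False, 30, 1., 1)
samp.initialize(torch.zeros(1), lengths)   # device/lengths are supplied here

word_probs = torch.full((batch_sz, n_words), -float("inf"))
word_probs[:, 5] = 0                       # every batch entry predicts token 5
samp.advance(word_probs, torch.randn(1, batch_sz, 53))
print(samp.topk_scores)                    # scores of the sampled tokens (0.0 here)
```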

0
opennmt/onmt/tests/test_models.sh Normal file → Executable file


@ -12,7 +12,7 @@ import codecs
import onmt import onmt
import onmt.inputters import onmt.inputters
import onmt.opts import onmt.opts
import preprocess import onmt.bin.preprocess as preprocess
parser = configargparse.ArgumentParser(description='preprocess.py') parser = configargparse.ArgumentParser(description='preprocess.py')
@ -49,11 +49,12 @@ class TestData(unittest.TestCase):
src_reader = onmt.inputters.str2reader[opt.data_type].from_opt(opt) src_reader = onmt.inputters.str2reader[opt.data_type].from_opt(opt)
tgt_reader = onmt.inputters.str2reader["text"].from_opt(opt) tgt_reader = onmt.inputters.str2reader["text"].from_opt(opt)
align_reader = onmt.inputters.str2reader["text"].from_opt(opt)
preprocess.build_save_dataset( preprocess.build_save_dataset(
'train', fields, src_reader, tgt_reader, opt) 'train', fields, src_reader, tgt_reader, align_reader, opt)
preprocess.build_save_dataset( preprocess.build_save_dataset(
'valid', fields, src_reader, tgt_reader, opt) 'valid', fields, src_reader, tgt_reader, align_reader, opt)
# Remove the generated *pt files. # Remove the generated *pt files.
for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
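
For context on the new `align_reader` threaded through `build_save_dataset` above: the alignment side is read with the same text reader as the target, one line per sentence pair of `srcIdx-tgtIdx` pairs (Pharaoh/fast_align style). A hypothetical data file, purely for illustration (the path and contents are made up):

```
# Write a toy alignment file in the "i-j" pair format the align reader expects.
with open("toy.align", "w") as f:
    f.write("0-0 1-1 2-3 3-2\n")   # sentence pair 1
    f.write("0-0 1-2 2-1\n")       # sentence pair 2
```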


@ -112,27 +112,24 @@ class TestServerModel(unittest.TestCase):
sm = ServerModel(opt, model_id, model_root=model_root, load=True) sm = ServerModel(opt, model_id, model_root=model_root, load=True)
inp = [{"src": "hello how are you today"}, inp = [{"src": "hello how are you today"},
{"src": "good morning to you ."}] {"src": "good morning to you ."}]
results, scores, n_best, time = sm.run(inp) results, scores, n_best, time, aligns = sm.run(inp)
self.assertIsInstance(results, list) self.assertIsInstance(results, list)
for sentence_string in results: for sentence_string in results:
self.assertIsInstance(sentence_string, string_types) self.assertIsInstance(sentence_string, string_types)
self.assertIsInstance(scores, list) self.assertIsInstance(scores, list)
for elem in scores: for elem in scores:
self.assertIsInstance(elem, float) self.assertIsInstance(elem, float)
self.assertIsInstance(aligns, list)
for align_list in aligns:
for align_string in align_list:
if align_string is not None:
self.assertIsInstance(align_string, string_types)
self.assertEqual(len(results), len(scores)) self.assertEqual(len(results), len(scores))
self.assertEqual(len(scores), len(inp)) self.assertEqual(len(scores), len(inp) * n_best)
self.assertEqual(n_best, 1)
self.assertEqual(len(time), 1) self.assertEqual(len(time), 1)
self.assertIsInstance(time, dict) self.assertIsInstance(time, dict)
self.assertIn("translation", time) self.assertIn("translation", time)
def test_nbest_init_fails(self):
model_id = 0
opt = {"models": ["test_model.pt"], "n_best": 2}
model_root = TEST_DIR
with self.assertRaises(ValueError):
ServerModel(opt, model_id, model_root=model_root, load=True)
class TestTranslationServer(unittest.TestCase): class TestTranslationServer(unittest.TestCase):
# this could be considered an integration test because it touches # this could be considered an integration test because it touches
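
The server test changes above reflect a new return signature: `ServerModel.run` now returns word alignments as a fifth element, and `scores` carries one entry per returned hypothesis (`len(inp) * n_best`) rather than one per input sentence. A hedged sketch of consuming it, reusing the `sm` and `inp` built in the test (whether alignments are populated depends on the model and options):

```
def check_run(sm, inp):
    # sm is the ServerModel instance and inp the input list from the test above
    results, scores, n_best, time, aligns = sm.run(inp)
    assert len(scores) == len(inp) * n_best      # one score per hypothesis
    for align_list in aligns:                    # one list per input sentence
        for align_string in align_list:
            if align_string is not None:         # "srcIdx-tgtIdx ..." pairs
                print(align_string)
    return results
```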

7
opennmt/onmt/train_single.py Normal file → Executable file

@ -4,7 +4,7 @@ import os
import torch import torch
from onmt.inputters.inputter import build_dataset_iter, \ from onmt.inputters.inputter import build_dataset_iter, patch_fields, \
load_old_vocab, old_style_vocab, build_dataset_iter_multiple load_old_vocab, old_style_vocab, build_dataset_iter_multiple
from onmt.model_builder import build_model from onmt.model_builder import build_model
from onmt.utils.optimizers import Optimizer from onmt.utils.optimizers import Optimizer
@ -69,6 +69,9 @@ def main(opt, device_id, batch_queue=None, semaphore=None):
else: else:
fields = vocab fields = vocab
# patch for fields that may be missing in old data/model
patch_fields(opt, fields)
# Report src and tgt vocab sizes, including for features # Report src and tgt vocab sizes, including for features
for side in ['src', 'tgt']: for side in ['src', 'tgt']:
f = fields[side] f = fields[side]
@ -142,5 +145,5 @@ def main(opt, device_id, batch_queue=None, semaphore=None):
valid_iter=valid_iter, valid_iter=valid_iter,
valid_steps=opt.valid_steps) valid_steps=opt.valid_steps)
if opt.tensorboard: if trainer.report_manager.tensorboard_writer is not None:
trainer.report_manager.tensorboard_writer.close() trainer.report_manager.tensorboard_writer.close()


@ -9,7 +9,6 @@
users of this library) for the strategy things we do. users of this library) for the strategy things we do.
""" """
from copy import deepcopy
import torch import torch
import traceback import traceback
@ -58,19 +57,39 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None):
opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) \ opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) \
if opt.early_stopping > 0 else None if opt.early_stopping > 0 else None
report_manager = onmt.utils.build_report_manager(opt) source_noise = None
if len(opt.src_noise) > 0:
src_field = dict(fields)["src"].base_field
corpus_id_field = dict(fields).get("corpus_id", None)
if corpus_id_field is not None:
ids_to_noise = corpus_id_field.numericalize(opt.data_to_noise)
else:
ids_to_noise = None
source_noise = onmt.modules.source_noise.MultiNoise(
opt.src_noise,
opt.src_noise_prob,
ids_to_noise=ids_to_noise,
pad_idx=src_field.pad_token,
end_of_sentence_mask=src_field.end_of_sentence_mask,
word_start_mask=src_field.word_start_mask,
device_id=device_id
)
report_manager = onmt.utils.build_report_manager(opt, gpu_rank)
trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size, trainer = onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size,
shard_size, norm_method, shard_size, norm_method,
accum_count, accum_steps, accum_count, accum_steps,
n_gpu, gpu_rank, n_gpu, gpu_rank,
gpu_verbose_level, report_manager, gpu_verbose_level, report_manager,
with_align=True if opt.lambda_align > 0 else False,
model_saver=model_saver if gpu_rank == 0 else None, model_saver=model_saver if gpu_rank == 0 else None,
average_decay=average_decay, average_decay=average_decay,
average_every=average_every, average_every=average_every,
model_dtype=opt.model_dtype, model_dtype=opt.model_dtype,
earlystopper=earlystopper, earlystopper=earlystopper,
dropout=dropout, dropout=dropout,
dropout_steps=dropout_steps) dropout_steps=dropout_steps,
source_noise=source_noise)
return trainer return trainer
@ -104,10 +123,11 @@ class Trainer(object):
trunc_size=0, shard_size=32, trunc_size=0, shard_size=32,
norm_method="sents", accum_count=[1], norm_method="sents", accum_count=[1],
accum_steps=[0], accum_steps=[0],
n_gpu=1, gpu_rank=1, n_gpu=1, gpu_rank=1, gpu_verbose_level=0,
gpu_verbose_level=0, report_manager=None, model_saver=None, report_manager=None, with_align=False, model_saver=None,
average_decay=0, average_every=1, model_dtype='fp32', average_decay=0, average_every=1, model_dtype='fp32',
earlystopper=None, dropout=[0.3], dropout_steps=[0]): earlystopper=None, dropout=[0.3], dropout_steps=[0],
source_noise=None):
# Basic attributes. # Basic attributes.
self.model = model self.model = model
self.train_loss = train_loss self.train_loss = train_loss
@ -123,6 +143,7 @@ class Trainer(object):
self.gpu_rank = gpu_rank self.gpu_rank = gpu_rank
self.gpu_verbose_level = gpu_verbose_level self.gpu_verbose_level = gpu_verbose_level
self.report_manager = report_manager self.report_manager = report_manager
self.with_align = with_align
self.model_saver = model_saver self.model_saver = model_saver
self.average_decay = average_decay self.average_decay = average_decay
self.moving_average = None self.moving_average = None
@ -131,6 +152,7 @@ class Trainer(object):
self.earlystopper = earlystopper self.earlystopper = earlystopper
self.dropout = dropout self.dropout = dropout
self.dropout_steps = dropout_steps self.dropout_steps = dropout_steps
self.source_noise = source_noise
for i in range(len(self.accum_count_l)): for i in range(len(self.accum_count_l)):
assert self.accum_count_l[i] > 0 assert self.accum_count_l[i] > 0
@ -290,13 +312,16 @@ class Trainer(object):
Returns: Returns:
:obj:`nmt.Statistics`: validation loss statistics :obj:`nmt.Statistics`: validation loss statistics
""" """
valid_model = self.model
if moving_average: if moving_average:
valid_model = deepcopy(self.model) # swap model params w/ moving average
# (and keep the original parameters)
model_params_data = []
for avg, param in zip(self.moving_average, for avg, param in zip(self.moving_average,
valid_model.parameters()): valid_model.parameters()):
param.data = avg.data model_params_data.append(param.data)
else: param.data = avg.data.half() if self.optim._fp16 == "legacy" \
valid_model = self.model else avg.data
# Set model in validating mode. # Set model in validating mode.
valid_model.eval() valid_model.eval()
@ -310,17 +335,19 @@ class Trainer(object):
tgt = batch.tgt tgt = batch.tgt
# F-prop through the model. # F-prop through the model.
outputs, attns = valid_model(src, tgt, src_lengths) outputs, attns = valid_model(src, tgt, src_lengths,
with_align=self.with_align)
# Compute loss. # Compute loss.
_, batch_stats = self.valid_loss(batch, outputs, attns) _, batch_stats = self.valid_loss(batch, outputs, attns)
# Update statistics. # Update statistics.
stats.update(batch_stats) stats.update(batch_stats)
if moving_average: if moving_average:
del valid_model for param_data, param in zip(model_params_data,
else: self.model.parameters()):
param.data = param_data
# Set model back to training mode. # Set model back to training mode.
valid_model.train() valid_model.train()
@ -339,6 +366,8 @@ class Trainer(object):
else: else:
trunc_size = target_size trunc_size = target_size
batch = self.maybe_noise_source(batch)
src, src_lengths = batch.src if isinstance(batch.src, tuple) \ src, src_lengths = batch.src if isinstance(batch.src, tuple) \
else (batch.src, None) else (batch.src, None)
if src_lengths is not None: if src_lengths is not None:
@ -354,7 +383,9 @@ class Trainer(object):
# 2. F-prop all but generator. # 2. F-prop all but generator.
if self.accum_count == 1: if self.accum_count == 1:
self.optim.zero_grad() self.optim.zero_grad()
outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt)
outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt,
with_align=self.with_align)
bptt = True bptt = True
# 3. Compute loss. # 3. Compute loss.
@ -454,3 +485,8 @@ class Trainer(object):
return self.report_manager.report_step( return self.report_manager.report_step(
learning_rate, step, train_stats=train_stats, learning_rate, step, train_stats=train_stats,
valid_stats=valid_stats) valid_stats=valid_stats)
def maybe_noise_source(self, batch):
if self.source_noise is not None:
return self.source_noise(batch)
return batch
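
The `validate()` change above replaces a full `deepcopy` of the model with an in-place parameter swap: the trained weights are stashed, the moving-average weights are copied in for evaluation, and the originals are restored afterwards. A generic sketch of that pattern (names are illustrative, not the Trainer's attributes; the fp16 half-cast from the diff is omitted):

```
import torch

def validate_with_moving_average(model, moving_average, run_validation):
    saved = []
    for avg, param in zip(moving_average, model.parameters()):
        saved.append(param.data)          # stash trained weights
        param.data = avg.data             # evaluate with averaged weights
    model.eval()
    try:
        with torch.no_grad():
            stats = run_validation(model)
    finally:
        for data, param in zip(saved, model.parameters()):
            param.data = data             # restore trained weights
        model.train()
    return stats
```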


@ -1,15 +1,14 @@
""" Modules for translation """ """ Modules for translation """
from onmt.translate.translator import Translator from onmt.translate.translator import Translator
from onmt.translate.translation import Translation, TranslationBuilder from onmt.translate.translation import Translation, TranslationBuilder
from onmt.translate.beam import Beam, GNMTGlobalScorer from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer
from onmt.translate.beam_search import BeamSearch
from onmt.translate.decode_strategy import DecodeStrategy from onmt.translate.decode_strategy import DecodeStrategy
from onmt.translate.random_sampling import RandomSampling from onmt.translate.greedy_search import GreedySearch
from onmt.translate.penalties import PenaltyBuilder from onmt.translate.penalties import PenaltyBuilder
from onmt.translate.translation_server import TranslationServer, \ from onmt.translate.translation_server import TranslationServer, \
ServerModelError ServerModelError
__all__ = ['Translator', 'Translation', 'Beam', 'BeamSearch', __all__ = ['Translator', 'Translation', 'BeamSearch',
'GNMTGlobalScorer', 'TranslationBuilder', 'GNMTGlobalScorer', 'TranslationBuilder',
'PenaltyBuilder', 'TranslationServer', 'ServerModelError', 'PenaltyBuilder', 'TranslationServer', 'ServerModelError',
"DecodeStrategy", "RandomSampling"] "DecodeStrategy", "GreedySearch"]


@ -1,293 +0,0 @@
from __future__ import division
import torch
from onmt.translate import penalties
import warnings
class Beam(object):
"""Class for managing the internals of the beam search process.
Takes care of beams, back pointers, and scores.
Args:
size (int): Number of beams to use.
pad (int): Magic integer in output vocab.
bos (int): Magic integer in output vocab.
eos (int): Magic integer in output vocab.
n_best (int): Don't stop until at least this many beams have
reached EOS.
cuda (bool): use gpu
global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
min_length (int): Shortest acceptable generation, not counting
begin-of-sentence or end-of-sentence.
stepwise_penalty (bool): Apply coverage penalty at every step.
block_ngram_repeat (int): Block beams where
``block_ngram_repeat``-grams repeat.
exclusion_tokens (set[int]): If a gram contains any of these
token indices, it may repeat.
"""
def __init__(self, size, pad, bos, eos,
n_best=1, cuda=False,
global_scorer=None,
min_length=0,
stepwise_penalty=False,
block_ngram_repeat=0,
exclusion_tokens=set()):
self.size = size
self.tt = torch.cuda if cuda else torch
# The score for each translation on the beam.
self.scores = self.tt.FloatTensor(size).zero_()
self.all_scores = []
# The backpointers at each time-step.
self.prev_ks = []
# The outputs at each time-step.
self.next_ys = [self.tt.LongTensor(size)
.fill_(pad)]
self.next_ys[0][0] = bos
# Has EOS topped the beam yet.
self._eos = eos
self.eos_top = False
# The attentions (matrix) for each time.
self.attn = []
# Time and k pair for finished.
self.finished = []
self.n_best = n_best
# Information for global scoring.
self.global_scorer = global_scorer
self.global_state = {}
# Minimum prediction length
self.min_length = min_length
# Apply Penalty at every step
self.stepwise_penalty = stepwise_penalty
self.block_ngram_repeat = block_ngram_repeat
self.exclusion_tokens = exclusion_tokens
@property
def current_predictions(self):
return self.next_ys[-1]
@property
def current_origin(self):
"""Get the backpointers for the current timestep."""
return self.prev_ks[-1]
def advance(self, word_probs, attn_out):
"""
Given probs over words for every last beam `word_probs` and attention
`attn_out`: compute and update the beam search.
Args:
word_probs (FloatTensor): probs of advancing from the last step
``(K, words)``
attn_out (FloatTensor): attention at the last step
Returns:
bool: True if beam search is complete.
"""
num_words = word_probs.size(1)
if self.stepwise_penalty:
self.global_scorer.update_score(self, attn_out)
# force the output to be longer than self.min_length
cur_len = len(self.next_ys)
if cur_len <= self.min_length:
# assumes there are len(word_probs) predictions OTHER
# than EOS that are greater than -1e20
for k in range(len(word_probs)):
word_probs[k][self._eos] = -1e20
# Sum the previous scores.
if len(self.prev_ks) > 0:
beam_scores = word_probs + self.scores.unsqueeze(1)
# Don't let EOS have children.
for i in range(self.next_ys[-1].size(0)):
if self.next_ys[-1][i] == self._eos:
beam_scores[i] = -1e20
# Block ngram repeats
if self.block_ngram_repeat > 0:
le = len(self.next_ys)
for j in range(self.next_ys[-1].size(0)):
hyp, _ = self.get_hyp(le - 1, j)
ngrams = set()
fail = False
gram = []
for i in range(le - 1):
# Last n tokens, n = block_ngram_repeat
gram = (gram +
[hyp[i].item()])[-self.block_ngram_repeat:]
# Skip the blocking if it is in the exclusion list
if set(gram) & self.exclusion_tokens:
continue
if tuple(gram) in ngrams:
fail = True
ngrams.add(tuple(gram))
if fail:
beam_scores[j] = -10e20
else:
beam_scores = word_probs[0]
flat_beam_scores = beam_scores.view(-1)
best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0,
True, True)
self.all_scores.append(self.scores)
self.scores = best_scores
# best_scores_id is flattened beam x word array, so calculate which
# word and beam each score came from
prev_k = best_scores_id / num_words
self.prev_ks.append(prev_k)
self.next_ys.append((best_scores_id - prev_k * num_words))
self.attn.append(attn_out.index_select(0, prev_k))
self.global_scorer.update_global_state(self)
for i in range(self.next_ys[-1].size(0)):
if self.next_ys[-1][i] == self._eos:
global_scores = self.global_scorer.score(self, self.scores)
s = global_scores[i]
self.finished.append((s, len(self.next_ys) - 1, i))
# End condition is when top-of-beam is EOS and no global score.
if self.next_ys[-1][0] == self._eos:
self.all_scores.append(self.scores)
self.eos_top = True
@property
def done(self):
return self.eos_top and len(self.finished) >= self.n_best
def sort_finished(self, minimum=None):
if minimum is not None:
i = 0
# Add from beam until we have minimum outputs.
while len(self.finished) < minimum:
global_scores = self.global_scorer.score(self, self.scores)
s = global_scores[i]
self.finished.append((s, len(self.next_ys) - 1, i))
i += 1
self.finished.sort(key=lambda a: -a[0])
scores = [sc for sc, _, _ in self.finished]
ks = [(t, k) for _, t, k in self.finished]
return scores, ks
def get_hyp(self, timestep, k):
"""Walk back to construct the full hypothesis."""
hyp, attn = [], []
for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1):
hyp.append(self.next_ys[j + 1][k])
attn.append(self.attn[j][k])
k = self.prev_ks[j][k]
return hyp[::-1], torch.stack(attn[::-1])
class GNMTGlobalScorer(object):
"""NMT re-ranking.
Args:
alpha (float): Length parameter.
beta (float): Coverage parameter.
length_penalty (str): Length penalty strategy.
coverage_penalty (str): Coverage penalty strategy.
Attributes:
alpha (float): See above.
beta (float): See above.
length_penalty (callable): See :class:`penalties.PenaltyBuilder`.
coverage_penalty (callable): See :class:`penalties.PenaltyBuilder`.
has_cov_pen (bool): See :class:`penalties.PenaltyBuilder`.
has_len_pen (bool): See :class:`penalties.PenaltyBuilder`.
"""
@classmethod
def from_opt(cls, opt):
return cls(
opt.alpha,
opt.beta,
opt.length_penalty,
opt.coverage_penalty)
def __init__(self, alpha, beta, length_penalty, coverage_penalty):
self._validate(alpha, beta, length_penalty, coverage_penalty)
self.alpha = alpha
self.beta = beta
penalty_builder = penalties.PenaltyBuilder(coverage_penalty,
length_penalty)
self.has_cov_pen = penalty_builder.has_cov_pen
# Term will be subtracted from probability
self.cov_penalty = penalty_builder.coverage_penalty
self.has_len_pen = penalty_builder.has_len_pen
# Probability will be divided by this
self.length_penalty = penalty_builder.length_penalty
@classmethod
def _validate(cls, alpha, beta, length_penalty, coverage_penalty):
# these warnings indicate that either the alpha/beta
# forces a penalty to be a no-op, or a penalty is a no-op but
# the alpha/beta would suggest otherwise.
if length_penalty is None or length_penalty == "none":
if alpha != 0:
warnings.warn("Non-default `alpha` with no length penalty. "
"`alpha` has no effect.")
else:
# using some length penalty
if length_penalty == "wu" and alpha == 0.:
warnings.warn("Using length penalty Wu with alpha==0 "
"is equivalent to using length penalty none.")
if coverage_penalty is None or coverage_penalty == "none":
if beta != 0:
warnings.warn("Non-default `beta` with no coverage penalty. "
"`beta` has no effect.")
else:
# using some coverage penalty
if beta == 0.:
warnings.warn("Non-default coverage penalty with beta==0 "
"is equivalent to using coverage penalty none.")
def score(self, beam, logprobs):
"""Rescore a prediction based on penalty functions."""
len_pen = self.length_penalty(len(beam.next_ys), self.alpha)
normalized_probs = logprobs / len_pen
if not beam.stepwise_penalty:
penalty = self.cov_penalty(beam.global_state["coverage"],
self.beta)
normalized_probs -= penalty
return normalized_probs
def update_score(self, beam, attn):
"""Update scores of a Beam that is not finished."""
if "prev_penalty" in beam.global_state.keys():
beam.scores.add_(beam.global_state["prev_penalty"])
penalty = self.cov_penalty(beam.global_state["coverage"] + attn,
self.beta)
beam.scores.sub_(penalty)
def update_global_state(self, beam):
"""Keeps the coverage vector as sum of attentions."""
if len(beam.prev_ks) == 1:
beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0)
beam.global_state["coverage"] = beam.attn[-1]
self.cov_total = beam.attn[-1].sum(1)
else:
self.cov_total += torch.min(beam.attn[-1],
beam.global_state['coverage']).sum(1)
beam.global_state["coverage"] = beam.global_state["coverage"] \
.index_select(0, beam.prev_ks[-1]).add(beam.attn[-1])
prev_penalty = self.cov_penalty(beam.global_state["coverage"],
self.beta)
beam.global_state["prev_penalty"] = prev_penalty


@ -1,6 +1,9 @@
import torch import torch
from onmt.translate import penalties
from onmt.translate.decode_strategy import DecodeStrategy from onmt.translate.decode_strategy import DecodeStrategy
from onmt.utils.misc import tile
import warnings
class BeamSearch(DecodeStrategy): class BeamSearch(DecodeStrategy):
@ -19,15 +22,12 @@ class BeamSearch(DecodeStrategy):
eos (int): See base. eos (int): See base.
n_best (int): Don't stop until at least this many beams have n_best (int): Don't stop until at least this many beams have
reached EOS. reached EOS.
mb_device (torch.device or str): See base ``device``.
global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
min_length (int): See base. min_length (int): See base.
max_length (int): See base. max_length (int): See base.
return_attention (bool): See base. return_attention (bool): See base.
block_ngram_repeat (int): See base. block_ngram_repeat (int): See base.
exclusion_tokens (set[int]): See base. exclusion_tokens (set[int]): See base.
memory_lengths (LongTensor): Lengths of encodings. Used for
masking attentions.
Attributes: Attributes:
top_beam_finished (ByteTensor): Shape ``(B,)``. top_beam_finished (ByteTensor): Shape ``(B,)``.
@ -36,6 +36,8 @@ class BeamSearch(DecodeStrategy):
alive_seq (LongTensor): See base. alive_seq (LongTensor): See base.
topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These
are the scores used for the topk operation. are the scores used for the topk operation.
memory_lengths (LongTensor): Lengths of encodings. Used for
masking attentions.
select_indices (LongTensor or NoneType): Shape select_indices (LongTensor or NoneType): Shape
``(B x beam_size,)``. This is just a flat view of the ``(B x beam_size,)``. This is just a flat view of the
``_batch_index``. ``_batch_index``.
@ -53,19 +55,18 @@ class BeamSearch(DecodeStrategy):
of score (float), sequence (long), and attention (float or None). of score (float), sequence (long), and attention (float or None).
""" """
def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device, def __init__(self, beam_size, batch_size, pad, bos, eos, n_best,
global_scorer, min_length, max_length, return_attention, global_scorer, min_length, max_length, return_attention,
block_ngram_repeat, exclusion_tokens, memory_lengths, block_ngram_repeat, exclusion_tokens,
stepwise_penalty, ratio): stepwise_penalty, ratio):
super(BeamSearch, self).__init__( super(BeamSearch, self).__init__(
pad, bos, eos, batch_size, mb_device, beam_size, min_length, pad, bos, eos, batch_size, beam_size, min_length,
block_ngram_repeat, exclusion_tokens, return_attention, block_ngram_repeat, exclusion_tokens, return_attention,
max_length) max_length)
# beam parameters # beam parameters
self.global_scorer = global_scorer self.global_scorer = global_scorer
self.beam_size = beam_size self.beam_size = beam_size
self.n_best = n_best self.n_best = n_best
self.batch_size = batch_size
self.ratio = ratio self.ratio = ratio
# result caching # result caching
@ -73,26 +74,14 @@ class BeamSearch(DecodeStrategy):
# beam state # beam state
self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8) self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
self.best_scores = torch.full([batch_size], -1e10, dtype=torch.float, # BoolTensor was introduced in pytorch 1.2
device=mb_device) try:
self.top_beam_finished = self.top_beam_finished.bool()
except AttributeError:
pass
self._batch_offset = torch.arange(batch_size, dtype=torch.long) self._batch_offset = torch.arange(batch_size, dtype=torch.long)
self._beam_offset = torch.arange(
0, batch_size * beam_size, step=beam_size, dtype=torch.long,
device=mb_device)
self.topk_log_probs = torch.tensor(
[0.0] + [float("-inf")] * (beam_size - 1), device=mb_device
).repeat(batch_size)
self.select_indices = None
self._memory_lengths = memory_lengths
# buffers for the topk scores and 'backpointer' self.select_indices = None
self.topk_scores = torch.empty((batch_size, beam_size),
dtype=torch.float, device=mb_device)
self.topk_ids = torch.empty((batch_size, beam_size), dtype=torch.long,
device=mb_device)
self._batch_index = torch.empty([batch_size, beam_size],
dtype=torch.long, device=mb_device)
self.done = False self.done = False
# "global state" of the old beam # "global state" of the old beam
self._prev_penalty = None self._prev_penalty = None
@ -104,20 +93,60 @@ class BeamSearch(DecodeStrategy):
not stepwise_penalty and self.global_scorer.has_cov_pen) not stepwise_penalty and self.global_scorer.has_cov_pen)
self._cov_pen = self.global_scorer.has_cov_pen self._cov_pen = self.global_scorer.has_cov_pen
def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
"""Initialize for decoding.
Repeat src objects `beam_size` times.
"""
def fn_map_state(state, dim):
return tile(state, self.beam_size, dim=dim)
if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, self.beam_size, dim=1)
for x in memory_bank)
mb_device = memory_bank[0].device
else:
memory_bank = tile(memory_bank, self.beam_size, dim=1)
mb_device = memory_bank.device
if src_map is not None:
src_map = tile(src_map, self.beam_size, dim=1)
if device is None:
device = mb_device
self.memory_lengths = tile(src_lengths, self.beam_size)
super(BeamSearch, self).initialize(
memory_bank, self.memory_lengths, src_map, device)
self.best_scores = torch.full(
[self.batch_size], -1e10, dtype=torch.float, device=device)
self._beam_offset = torch.arange(
0, self.batch_size * self.beam_size, step=self.beam_size,
dtype=torch.long, device=device)
self.topk_log_probs = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1), device=device
).repeat(self.batch_size)
# buffers for the topk scores and 'backpointer'
self.topk_scores = torch.empty((self.batch_size, self.beam_size),
dtype=torch.float, device=device)
self.topk_ids = torch.empty((self.batch_size, self.beam_size),
dtype=torch.long, device=device)
self._batch_index = torch.empty([self.batch_size, self.beam_size],
dtype=torch.long, device=device)
return fn_map_state, memory_bank, self.memory_lengths, src_map
@property @property
def current_predictions(self): def current_predictions(self):
return self.alive_seq[:, -1] return self.alive_seq[:, -1]
@property
def current_origin(self):
return self.select_indices
@property @property
def current_backptr(self): def current_backptr(self):
# for testing # for testing
return self.select_indices.view(self.batch_size, self.beam_size)\ return self.select_indices.view(self.batch_size, self.beam_size)\
.fmod(self.beam_size) .fmod(self.beam_size)
@property
def batch_offset(self):
return self._batch_offset
def advance(self, log_probs, attn, partial=None, partialf=None): def advance(self, log_probs, attn, partial=None, partialf=None):
vocab_size = log_probs.size(-1) vocab_size = log_probs.size(-1)
@ -137,16 +166,19 @@ class BeamSearch(DecodeStrategy):
# Multiply probs by the beam probability. # Multiply probs by the beam probability.
log_probs += self.topk_log_probs.view(_B * self.beam_size, 1) log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
self.block_ngram_repeats(log_probs)
# if the sequence ends now, then the penalty is the current # if the sequence ends now, then the penalty is the current
# length + 1, to include the EOS token # length + 1, to include the EOS token
length_penalty = self.global_scorer.length_penalty( length_penalty = self.global_scorer.length_penalty(
step + 1, alpha=self.global_scorer.alpha) step + 1, alpha=self.global_scorer.alpha)
# Flatten probs into a list of possibilities.
curr_scores = log_probs / length_penalty curr_scores = log_probs / length_penalty
# Avoid any direction that would repeat unwanted ngrams
self.block_ngram_repeats(curr_scores)
# Flatten probs into a list of possibilities.
curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size) curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)
if partialf and step and step == len(partial) + 1: if partialf and step and step == len(partial) + 1:
# proxyf = [] # proxyf = []
# for i in range(self.beam_size): # for i in range(self.beam_size):
@ -169,7 +201,6 @@ class BeamSearch(DecodeStrategy):
torch.topk(curr_scores, self.beam_size, dim=-1, torch.topk(curr_scores, self.beam_size, dim=-1,
out=(self.topk_scores, self.topk_ids)) out=(self.topk_scores, self.topk_ids))
# if partialf and step and step == len(partial) + 1: # if partialf and step and step == len(partial) + 1:
# print(self.topk_scores) # print(self.topk_scores)
# Recover log probs. # Recover log probs.
@ -181,18 +212,18 @@ class BeamSearch(DecodeStrategy):
torch.div(self.topk_ids, vocab_size, out=self._batch_index) torch.div(self.topk_ids, vocab_size, out=self._batch_index)
self._batch_index += self._beam_offset[:_B].unsqueeze(1) self._batch_index += self._beam_offset[:_B].unsqueeze(1)
self.select_indices = self._batch_index.view(_B * self.beam_size) self.select_indices = self._batch_index.view(_B * self.beam_size)
self.topk_ids.fmod_(vocab_size) # resolve true word ids self.topk_ids.fmod_(vocab_size) # resolve true word ids
# Append last prediction. # Append last prediction.
if partial and step and step <= len(partial): if partial and step and step <= len(partial):
self.topk_ids = torch.full([self.batch_size * self.beam_size, 1], partial[step-1], dtype=torch.long, device=torch.device('cpu')) self.topk_ids = torch.full([self.batch_size * self.beam_size, 1], partial[step-1], dtype=torch.long, device=torch.device('cpu'))
self.select_indices = torch.tensor([0, 0, 0, 0, 0]) self.select_indices = torch.tensor([0, 0, 0, 0, 0])
self.alive_seq = torch.cat( self.alive_seq = torch.cat(
[self.alive_seq.index_select(0, self.select_indices), [self.alive_seq.index_select(0, self.select_indices),
self.topk_ids.view(_B * self.beam_size, 1)], -1) self.topk_ids.view(_B * self.beam_size, 1)], -1)
self.maybe_update_forbidden_tokens()
if self.return_attention or self._cov_pen: if self.return_attention or self._cov_pen:
current_attn = attn.index_select(1, self.select_indices) current_attn = attn.index_select(1, self.select_indices)
if step == 1: if step == 1:
@ -219,10 +250,9 @@ class BeamSearch(DecodeStrategy):
cov_penalty = self.global_scorer.cov_penalty( cov_penalty = self.global_scorer.cov_penalty(
self._coverage, self._coverage,
beta=self.global_scorer.beta) beta=self.global_scorer.beta)
self.topk_scores -= cov_penalty.view(_B, self.beam_size) self.topk_scores -= cov_penalty.view(_B, self.beam_size).float()
self.is_finished = self.topk_ids.eq(self.eos) self.is_finished = self.topk_ids.eq(self.eos)
# print(self.is_finished, self.topk_ids)
self.ensure_max_length() self.ensure_max_length()
def update_finished(self): def update_finished(self):
@ -240,11 +270,11 @@ class BeamSearch(DecodeStrategy):
step - 1, _B_old, self.beam_size, self.alive_attn.size(-1)) step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
if self.alive_attn is not None else None) if self.alive_attn is not None else None)
non_finished_batch = [] non_finished_batch = []
for i in range(self.is_finished.size(0)): for i in range(self.is_finished.size(0)): # Batch level
b = self._batch_offset[i] b = self._batch_offset[i]
finished_hyp = self.is_finished[i].nonzero().view(-1) finished_hyp = self.is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch. # Store finished hypotheses for this batch.
for j in finished_hyp: for j in finished_hyp: # Beam level: finished beam j in batch i
if self.ratio > 0: if self.ratio > 0:
s = self.topk_scores[i, j] / (step + 1) s = self.topk_scores[i, j] / (step + 1)
if self.best_scores[b] < s: if self.best_scores[b] < s:
@ -252,12 +282,12 @@ class BeamSearch(DecodeStrategy):
self.hypotheses[b].append(( self.hypotheses[b].append((
self.topk_scores[i, j], self.topk_scores[i, j],
predictions[i, j, 1:], # Ignore start_token. predictions[i, j, 1:], # Ignore start_token.
attention[:, i, j, :self._memory_lengths[i]] attention[:, i, j, :self.memory_lengths[i]]
if attention is not None else None)) if attention is not None else None))
# End condition is the top beam finished and we can return # End condition is the top beam finished and we can return
# n_best hypotheses. # n_best hypotheses.
if self.ratio > 0: if self.ratio > 0:
pred_len = self._memory_lengths[i] * self.ratio pred_len = self.memory_lengths[i] * self.ratio
finish_flag = ((self.topk_scores[i, 0] / pred_len) finish_flag = ((self.topk_scores[i, 0] / pred_len)
<= self.best_scores[b]) or \ <= self.best_scores[b]) or \
self.is_finished[i].all() self.is_finished[i].all()
@ -270,7 +300,7 @@ class BeamSearch(DecodeStrategy):
if n >= self.n_best: if n >= self.n_best:
break break
self.scores[b].append(score) self.scores[b].append(score)
self.predictions[b].append(pred) self.predictions[b].append(pred) # ``(batch, n_best,)``
self.attention[b].append( self.attention[b].append(
attn if attn is not None else []) attn if attn is not None else [])
else: else:
@ -307,3 +337,68 @@ class BeamSearch(DecodeStrategy):
if self._stepwise_cov_pen: if self._stepwise_cov_pen:
self._prev_penalty = self._prev_penalty.index_select( self._prev_penalty = self._prev_penalty.index_select(
0, non_finished) 0, non_finished)
class GNMTGlobalScorer(object):
"""NMT re-ranking.
Args:
alpha (float): Length parameter.
beta (float): Coverage parameter.
length_penalty (str): Length penalty strategy.
coverage_penalty (str): Coverage penalty strategy.
Attributes:
alpha (float): See above.
beta (float): See above.
length_penalty (callable): See :class:`penalties.PenaltyBuilder`.
coverage_penalty (callable): See :class:`penalties.PenaltyBuilder`.
has_cov_pen (bool): See :class:`penalties.PenaltyBuilder`.
has_len_pen (bool): See :class:`penalties.PenaltyBuilder`.
"""
@classmethod
def from_opt(cls, opt):
return cls(
opt.alpha,
opt.beta,
opt.length_penalty,
opt.coverage_penalty)
def __init__(self, alpha, beta, length_penalty, coverage_penalty):
self._validate(alpha, beta, length_penalty, coverage_penalty)
self.alpha = alpha
self.beta = beta
penalty_builder = penalties.PenaltyBuilder(coverage_penalty,
length_penalty)
self.has_cov_pen = penalty_builder.has_cov_pen
# Term will be subtracted from probability
self.cov_penalty = penalty_builder.coverage_penalty
self.has_len_pen = penalty_builder.has_len_pen
# Probability will be divided by this
self.length_penalty = penalty_builder.length_penalty
@classmethod
def _validate(cls, alpha, beta, length_penalty, coverage_penalty):
# these warnings indicate that either the alpha/beta
# forces a penalty to be a no-op, or a penalty is a no-op but
# the alpha/beta would suggest otherwise.
if length_penalty is None or length_penalty == "none":
if alpha != 0:
warnings.warn("Non-default `alpha` with no length penalty. "
"`alpha` has no effect.")
else:
# using some length penalty
if length_penalty == "wu" and alpha == 0.:
warnings.warn("Using length penalty Wu with alpha==0 "
"is equivalent to using length penalty none.")
if coverage_penalty is None or coverage_penalty == "none":
if beta != 0:
warnings.warn("Non-default `beta` with no coverage penalty. "
"`beta` has no effect.")
else:
# using some coverage penalty
if beta == 0.:
warnings.warn("Non-default coverage penalty with beta==0 "
"is equivalent to using coverage penalty none.")

@ -9,7 +9,6 @@ class DecodeStrategy(object):
bos (int): Magic integer in output vocab. bos (int): Magic integer in output vocab.
eos (int): Magic integer in output vocab. eos (int): Magic integer in output vocab.
batch_size (int): Current batch size. batch_size (int): Current batch size.
device (torch.device or str): Device for memory bank (encoder).
parallel_paths (int): Decoding strategies like beam search parallel_paths (int): Decoding strategies like beam search
use parallel paths. Each batch is repeated ``parallel_paths`` use parallel paths. Each batch is repeated ``parallel_paths``
times in relevant state tensors. times in relevant state tensors.
@ -54,7 +53,7 @@ class DecodeStrategy(object):
done (bool): See above. done (bool): See above.
""" """
def __init__(self, pad, bos, eos, batch_size, device, parallel_paths, def __init__(self, pad, bos, eos, batch_size, parallel_paths,
min_length, block_ngram_repeat, exclusion_tokens, min_length, block_ngram_repeat, exclusion_tokens,
return_attention, max_length): return_attention, max_length):
@ -63,27 +62,43 @@ class DecodeStrategy(object):
self.bos = bos self.bos = bos
self.eos = eos self.eos = eos
self.batch_size = batch_size
self.parallel_paths = parallel_paths
# result caching # result caching
self.predictions = [[] for _ in range(batch_size)] self.predictions = [[] for _ in range(batch_size)]
self.scores = [[] for _ in range(batch_size)] self.scores = [[] for _ in range(batch_size)]
self.attention = [[] for _ in range(batch_size)] self.attention = [[] for _ in range(batch_size)]
self.alive_seq = torch.full(
[batch_size * parallel_paths, 1], self.bos,
dtype=torch.long, device=device)
self.is_finished = torch.zeros(
[batch_size, parallel_paths],
dtype=torch.uint8, device=device)
self.alive_attn = None self.alive_attn = None
self.min_length = min_length self.min_length = min_length
self.max_length = max_length self.max_length = max_length
self.block_ngram_repeat = block_ngram_repeat self.block_ngram_repeat = block_ngram_repeat
n_paths = batch_size * parallel_paths
self.forbidden_tokens = [dict() for _ in range(n_paths)]
self.exclusion_tokens = exclusion_tokens self.exclusion_tokens = exclusion_tokens
self.return_attention = return_attention self.return_attention = return_attention
self.done = False self.done = False
def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
"""DecodeStrategy subclasses should override :func:`initialize()`.
`initialize` should be called before all actions;
it prepares the necessary ingredients for decoding.
"""
if device is None:
device = torch.device('cpu')
self.alive_seq = torch.full(
[self.batch_size * self.parallel_paths, 1], self.bos,
dtype=torch.long, device=device)
self.is_finished = torch.zeros(
[self.batch_size, self.parallel_paths],
dtype=torch.uint8, device=device)
return None, memory_bank, src_lengths, src_map
def __len__(self): def __len__(self):
return self.alive_seq.shape[1] return self.alive_seq.shape[1]
@ -98,25 +113,75 @@ class DecodeStrategy(object):
self.is_finished.fill_(1) self.is_finished.fill_(1)
def block_ngram_repeats(self, log_probs): def block_ngram_repeats(self, log_probs):
cur_len = len(self) """
if self.block_ngram_repeat > 0 and cur_len > 1: We prevent the beam from going in any direction that would repeat any
ngram of size <block_ngram_repeat> more than once.
The way we do it: we maintain a list of all ngrams of size
<block_ngram_repeat> that is updated each time the beam advances, and
manually put any token that would lead to a repeated ngram to 0.
This improves on the previous version's complexity:
- previous version's complexity: batch_size * beam_size * len(self)
- current version's complexity: batch_size * beam_size
This improves on the previous version's accuracy;
- Previous version blocks the whole beam, whereas here we only
block specific tokens.
- Before the translation would fail when all beams contained
repeated ngrams. This is sure to never happen here.
"""
# we don't block anything if the user doesn't want it
if self.block_ngram_repeat <= 0:
return
# we can't block anything if the beam is too short
if len(self) < self.block_ngram_repeat:
return
n = self.block_ngram_repeat - 1
for path_idx in range(self.alive_seq.shape[0]): for path_idx in range(self.alive_seq.shape[0]):
# skip BOS # we check paths one by one
hyp = self.alive_seq[path_idx, 1:]
ngrams = set() current_ngram = tuple(self.alive_seq[path_idx, -n:].tolist())
fail = False forbidden_tokens = self.forbidden_tokens[path_idx].get(
gram = [] current_ngram, None)
for i in range(cur_len - 1): if forbidden_tokens is not None:
# Last n tokens, n = block_ngram_repeat log_probs[path_idx, list(forbidden_tokens)] = -10e20
gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:]
# skip the blocking if any token in gram is excluded def maybe_update_forbidden_tokens(self):
if set(gram) & self.exclusion_tokens: """We complete and reorder the list of forbidden_tokens"""
# we don't forbid anything if the user doesn't want it
if self.block_ngram_repeat <= 0:
return
# we can't forbid anything if the beam is too short
if len(self) < self.block_ngram_repeat:
return
n = self.block_ngram_repeat
forbidden_tokens = list()
for path_idx, seq in zip(self.select_indices, self.alive_seq):
# Reordering forbidden_tokens following beam selection
# We rebuild a dict to ensure we get the value and not the pointer
forbidden_tokens.append(
dict(self.forbidden_tokens[path_idx]))
# Grabbing the newly selected tokens and associated ngram
current_ngram = tuple(seq[-n:].tolist())
# skip the blocking if any token in current_ngram is excluded
if set(current_ngram) & self.exclusion_tokens:
continue continue
if tuple(gram) in ngrams:
fail = True forbidden_tokens[-1].setdefault(current_ngram[:-1], set())
ngrams.add(tuple(gram)) forbidden_tokens[-1][current_ngram[:-1]].add(current_ngram[-1])
if fail:
log_probs[path_idx] = -10e20 self.forbidden_tokens = forbidden_tokens
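The docstring above describes the per-path bookkeeping: after every step the newly completed ngram is recorded, and before the next step any token that would recreate a stored ngram is masked. A toy walk-through under the assumption `block_ngram_repeat=2`; the names here are illustrative, not the class attributes:

```python
# Toy walk-through (not the library code): no bigram may occur twice on a path.
block_ngram_repeat = 2
path = [5, 9, 5]                      # tokens generated so far on one beam path
forbidden = {}                        # (n-1)-gram prefix -> set of banned next tokens

# maybe_update_forbidden_tokens: record the bigram completed at each step
for t in range(1, len(path)):
    prefix, last = tuple(path[t - 1:t]), path[t]
    forbidden.setdefault(prefix, set()).add(last)
# forbidden == {(5,): {9}, (9,): {5}}

# block_ngram_repeats: before the next step the prefix is the last token (5,),
# so token 9 would recreate the bigram (5, 9) and has its log-prob pushed to -10e20
print(forbidden[tuple(path[-1:])])    # {9}
```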
def advance(self, log_probs, attn): def advance(self, log_probs, attn):
"""DecodeStrategy subclasses should override :func:`advance()`. """DecodeStrategy subclasses should override :func:`advance()`.

@ -56,7 +56,7 @@ def sample_with_temperature(logits, sampling_temp, keep_topk):
return topk_ids, topk_scores return topk_ids, topk_scores
class RandomSampling(DecodeStrategy): class GreedySearch(DecodeStrategy):
"""Select next tokens randomly from the top k possible next tokens. """Select next tokens randomly from the top k possible next tokens.
The ``scores`` attribute's lists are the score, after applying temperature, The ``scores`` attribute's lists are the score, after applying temperature,
@ -68,7 +68,6 @@ class RandomSampling(DecodeStrategy):
bos (int): See base. bos (int): See base.
eos (int): See base. eos (int): See base.
batch_size (int): See base. batch_size (int): See base.
device (torch.device or str): See base ``device``.
min_length (int): See base. min_length (int): See base.
max_length (int): See base. max_length (int): See base.
block_ngram_repeat (int): See base. block_ngram_repeat (int): See base.
@ -76,30 +75,49 @@ class RandomSampling(DecodeStrategy):
return_attention (bool): See base. return_attention (bool): See base.
max_length (int): See base. max_length (int): See base.
sampling_temp (float): See sampling_temp (float): See
:func:`~onmt.translate.random_sampling.sample_with_temperature()`. :func:`~onmt.translate.greedy_search.sample_with_temperature()`.
keep_topk (int): See keep_topk (int): See
:func:`~onmt.translate.random_sampling.sample_with_temperature()`. :func:`~onmt.translate.greedy_search.sample_with_temperature()`.
memory_length (LongTensor): Lengths of encodings. Used for
masking attention.
""" """
def __init__(self, pad, bos, eos, batch_size, device, def __init__(self, pad, bos, eos, batch_size, min_length,
min_length, block_ngram_repeat, exclusion_tokens, block_ngram_repeat, exclusion_tokens, return_attention,
return_attention, max_length, sampling_temp, keep_topk, max_length, sampling_temp, keep_topk):
memory_length): assert block_ngram_repeat == 0
super(RandomSampling, self).__init__( super(GreedySearch, self).__init__(
pad, bos, eos, batch_size, device, 1, pad, bos, eos, batch_size, 1, min_length, block_ngram_repeat,
min_length, block_ngram_repeat, exclusion_tokens, exclusion_tokens, return_attention, max_length)
return_attention, max_length)
self.sampling_temp = sampling_temp self.sampling_temp = sampling_temp
self.keep_topk = keep_topk self.keep_topk = keep_topk
self.topk_scores = None self.topk_scores = None
self.memory_length = memory_length
self.batch_size = batch_size def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
self.select_indices = torch.arange(self.batch_size, """Initialize for decoding."""
dtype=torch.long, device=device) fn_map_state = None
self.original_batch_idx = torch.arange(self.batch_size,
dtype=torch.long, device=device) if isinstance(memory_bank, tuple):
mb_device = memory_bank[0].device
else:
mb_device = memory_bank.device
if device is None:
device = mb_device
self.memory_lengths = src_lengths
super(GreedySearch, self).initialize(
memory_bank, src_lengths, src_map, device)
self.select_indices = torch.arange(
self.batch_size, dtype=torch.long, device=device)
self.original_batch_idx = torch.arange(
self.batch_size, dtype=torch.long, device=device)
return fn_map_state, memory_bank, self.memory_lengths, src_map
@property
def current_predictions(self):
return self.alive_seq[:, -1]
@property
def batch_offset(self):
return self.select_indices
def advance(self, log_probs, attn): def advance(self, log_probs, attn):
"""Select next tokens randomly from the top k possible next tokens. """Select next tokens randomly from the top k possible next tokens.
@ -138,7 +156,7 @@ class RandomSampling(DecodeStrategy):
self.scores[b_orig].append(self.topk_scores[b, 0]) self.scores[b_orig].append(self.topk_scores[b, 0])
self.predictions[b_orig].append(self.alive_seq[b, 1:]) self.predictions[b_orig].append(self.alive_seq[b, 1:])
self.attention[b_orig].append( self.attention[b_orig].append(
self.alive_attn[:, b, :self.memory_length[b]] self.alive_attn[:, b, :self.memory_lengths[b]]
if self.alive_attn is not None else []) if self.alive_attn is not None else [])
self.done = self.is_finished.all() self.done = self.is_finished.all()
if self.done: if self.done:
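`GreedySearch` relies on `sample_with_temperature` to pick the next token. A hedged, self-contained sketch of that behaviour (temperature scaling plus an optional top-k cut-off); the exact library implementation may differ in details such as the masking value:

```python
import torch

# Sketch only: greedy when temperature <= 0 or keep_topk == 1, otherwise
# temperature-scaled sampling restricted to the top-k candidates.
def sample_with_temperature_sketch(logits, sampling_temp=1.0, keep_topk=-1):
    if sampling_temp <= 0.0 or keep_topk == 1:
        topk_scores, topk_ids = logits.topk(1, dim=-1)
    else:
        logits = logits / sampling_temp
        if keep_topk > 0:
            # Mask everything outside the top-k before sampling
            kth_best = logits.topk(keep_topk, dim=-1)[0][:, -1:].expand_as(logits)
            logits = logits.masked_fill(logits < kth_best, -10000.0)
        dist = torch.distributions.Categorical(logits=logits)
        topk_ids = dist.sample().unsqueeze(-1)
        topk_scores = logits.gather(dim=1, index=topk_ids)
    return topk_ids, topk_scores

# Example: batch of 2, vocabulary of 5
ids, scores = sample_with_temperature_sketch(torch.randn(2, 5), 0.8, keep_topk=3)
```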

@ -3,23 +3,21 @@
from __future__ import print_function from __future__ import print_function
import codecs import codecs
import os import os
import math
import time import time
from itertools import count import numpy as np
from itertools import count, zip_longest
import torch import torch
import onmt.model_builder import onmt.model_builder
import onmt.translate.beam
import onmt.inputters as inputters import onmt.inputters as inputters
import onmt.decoders.ensemble import onmt.decoders.ensemble
from onmt.translate.beam_search import BeamSearch from onmt.translate.beam_search import BeamSearch
from onmt.translate.random_sampling import RandomSampling from onmt.translate.greedy_search import GreedySearch
from onmt.utils.misc import tile, set_random_seed from onmt.utils.misc import tile, set_random_seed, report_matrix
from onmt.utils.alignment import extract_alignment, build_align_pharaoh
from onmt.modules.copy_generator import collapse_copy_scores from onmt.modules.copy_generator import collapse_copy_scores
import editdistance
def build_translator(opt, report_score=True, logger=None, out_file=None): def build_translator(opt, report_score=True, logger=None, out_file=None):
if out_file is None: if out_file is None:
@ -38,12 +36,32 @@ def build_translator(opt, report_score=True, logger=None, out_file=None):
model_opt, model_opt,
global_scorer=scorer, global_scorer=scorer,
out_file=out_file, out_file=out_file,
report_align=opt.report_align,
report_score=report_score, report_score=report_score,
logger=logger logger=logger
) )
return translator return translator
def max_tok_len(new, count, sofar):
"""
In the token batching scheme, the number of sequences is limited
such that the total number of src/tgt tokens (including padding)
in a batch <= batch_size
"""
# Maintains the longest src and tgt length in the current batch
global max_src_in_batch # this is a hack
# Reset current longest length at a new batch (count=1)
if count == 1:
max_src_in_batch = 0
# max_tgt_in_batch = 0
# Src: [<bos> w1 ... wN <eos>]
max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2)
# Tgt: [w1 ... wM <eos>]
src_elements = count * max_src_in_batch
return src_elements
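The token-batching cap can be illustrated with toy numbers: the iterator keeps admitting examples while `count * max_src_in_batch` stays within the token budget. A standalone re-computation with hypothetical lengths, not a call into the module-level function:

```python
# Conceptual illustration of the cap computed by max_tok_len above.
batch_size = 4096            # token budget, includes padding
src_lengths = [20, 50, 33]   # hypothetical source sentence lengths in the stream

max_src_in_batch = 0
for count, n in enumerate(src_lengths, start=1):
    max_src_in_batch = max(max_src_in_batch, n + 2)   # +2 for <bos>/<eos>
    src_elements = count * max_src_in_batch
    print(count, src_elements)   # 22, 104, 156 -- all under 4096, batch keeps growing
```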
class Translator(object): class Translator(object):
"""Translate a batch of sentences with a saved model. """Translate a batch of sentences with a saved model.
@ -61,9 +79,9 @@ class Translator(object):
:class:`onmt.translate.decode_strategy.DecodeStrategy`. :class:`onmt.translate.decode_strategy.DecodeStrategy`.
beam_size (int): Number of beams. beam_size (int): Number of beams.
random_sampling_topk (int): See random_sampling_topk (int): See
:class:`onmt.translate.random_sampling.RandomSampling`. :class:`onmt.translate.greedy_search.GreedySearch`.
random_sampling_temp (int): See random_sampling_temp (int): See
:class:`onmt.translate.random_sampling.RandomSampling`. :class:`onmt.translate.greedy_search.GreedySearch`.
stepwise_penalty (bool): Whether coverage penalty is applied every step stepwise_penalty (bool): Whether coverage penalty is applied every step
or not. or not.
dump_beam (bool): Debugging option. dump_beam (bool): Debugging option.
@ -74,8 +92,6 @@ class Translator(object):
replace_unk (bool): Replace unknown token. replace_unk (bool): Replace unknown token.
data_type (str): Source data type. data_type (str): Source data type.
verbose (bool): Print/log every translation. verbose (bool): Print/log every translation.
report_bleu (bool): Print/log Bleu metric.
report_rouge (bool): Print/log Rouge metric.
report_time (bool): Print/log total time/frequency. report_time (bool): Print/log total time/frequency.
copy_attn (bool): Use copy attention. copy_attn (bool): Use copy attention.
global_scorer (onmt.translate.GNMTGlobalScorer): Translation global_scorer (onmt.translate.GNMTGlobalScorer): Translation
@ -104,14 +120,14 @@ class Translator(object):
block_ngram_repeat=0, block_ngram_repeat=0,
ignore_when_blocking=frozenset(), ignore_when_blocking=frozenset(),
replace_unk=False, replace_unk=False,
phrase_table="",
data_type="text", data_type="text",
verbose=False, verbose=False,
report_bleu=False,
report_rouge=False,
report_time=False, report_time=False,
copy_attn=False, copy_attn=False,
global_scorer=None, global_scorer=None,
out_file=None, out_file=None,
report_align=False,
report_score=True, report_score=True,
logger=None, logger=None,
seed=-1): seed=-1):
@ -151,10 +167,9 @@ class Translator(object):
if self.replace_unk and not self.model.decoder.attentional: if self.replace_unk and not self.model.decoder.attentional:
raise ValueError( raise ValueError(
"replace_unk requires an attentional decoder.") "replace_unk requires an attentional decoder.")
self.phrase_table = phrase_table
self.data_type = data_type self.data_type = data_type
self.verbose = verbose self.verbose = verbose
self.report_bleu = report_bleu
self.report_rouge = report_rouge
self.report_time = report_time self.report_time = report_time
self.copy_attn = copy_attn self.copy_attn = copy_attn
@ -165,6 +180,7 @@ class Translator(object):
raise ValueError( raise ValueError(
"Coverage penalty requires an attentional decoder.") "Coverage penalty requires an attentional decoder.")
self.out_file = out_file self.out_file = out_file
self.report_align = report_align
self.report_score = report_score self.report_score = report_score
self.logger = logger self.logger = logger
@ -192,6 +208,7 @@ class Translator(object):
model_opt, model_opt,
global_scorer=None, global_scorer=None,
out_file=None, out_file=None,
report_align=False,
report_score=True, report_score=True,
logger=None): logger=None):
"""Alternate constructor. """Alternate constructor.
@ -207,6 +224,7 @@ class Translator(object):
:func:`__init__()`.. :func:`__init__()`..
out_file (TextIO or codecs.StreamReaderWriter): See out_file (TextIO or codecs.StreamReaderWriter): See
:func:`__init__()`. :func:`__init__()`.
report_align (bool) : See :func:`__init__()`.
report_score (bool) : See :func:`__init__()`. report_score (bool) : See :func:`__init__()`.
logger (logging.Logger or NoneType): See :func:`__init__()`. logger (logging.Logger or NoneType): See :func:`__init__()`.
""" """
@ -231,14 +249,14 @@ class Translator(object):
block_ngram_repeat=opt.block_ngram_repeat, block_ngram_repeat=opt.block_ngram_repeat,
ignore_when_blocking=set(opt.ignore_when_blocking), ignore_when_blocking=set(opt.ignore_when_blocking),
replace_unk=opt.replace_unk, replace_unk=opt.replace_unk,
phrase_table=opt.phrase_table,
data_type=opt.data_type, data_type=opt.data_type,
verbose=opt.verbose, verbose=opt.verbose,
report_bleu=opt.report_bleu,
report_rouge=opt.report_rouge,
report_time=opt.report_time, report_time=opt.report_time,
copy_attn=model_opt.copy_attn, copy_attn=model_opt.copy_attn,
global_scorer=global_scorer, global_scorer=global_scorer,
out_file=out_file, out_file=out_file,
report_align=report_align,
report_score=report_score, report_score=report_score,
logger=logger, logger=logger,
seed=opt.seed) seed=opt.seed)
@ -266,7 +284,10 @@ class Translator(object):
tgt=None, tgt=None,
src_dir=None, src_dir=None,
batch_size=None, batch_size=None,
batch_type="sents",
attn_debug=False, attn_debug=False,
align_debug=False,
phrase_table="",
partial=None, partial=None,
dymax_len=None): dymax_len=None):
"""Translate content of ``src`` and get gold scores from ``tgt``. """Translate content of ``src`` and get gold scores from ``tgt``.
@ -278,6 +299,7 @@ class Translator(object):
for certain types of data). for certain types of data).
batch_size (int): size of examples per mini-batch batch_size (int): size of examples per mini-batch
attn_debug (bool): enables the attention logging attn_debug (bool): enables the attention logging
align_debug (bool): enables the word alignment logging
Returns: Returns:
(`list`, `list`) (`list`, `list`)
@ -327,12 +349,16 @@ class Translator(object):
if batch_size is None: if batch_size is None:
raise ValueError("batch_size must be set") raise ValueError("batch_size must be set")
src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
_readers, _data, _dir = inputters.Dataset.config(
[('src', src_data), ('tgt', tgt_data)])
# corpus_id field is useless here
if self.fields.get("corpus_id", None) is not None:
self.fields.pop('corpus_id')
data = inputters.Dataset( data = inputters.Dataset(
self.fields, self.fields, readers=_readers, data=_data, dirs=_dir,
readers=([self.src_reader, self.tgt_reader]
if tgt else [self.src_reader]),
data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
dirs=[src_dir, None] if tgt else [src_dir],
sort_key=inputters.str2sortkey[self.data_type], sort_key=inputters.str2sortkey[self.data_type],
filter_pred=self._filter_pred filter_pred=self._filter_pred
) )
@ -341,6 +367,7 @@ class Translator(object):
dataset=data, dataset=data,
device=self._dev, device=self._dev,
batch_size=batch_size, batch_size=batch_size,
batch_size_fn=max_tok_len if batch_type == "tokens" else None,
train=False, train=False,
sort=False, sort=False,
sort_within_batch=True, sort_within_batch=True,
@ -348,7 +375,8 @@ class Translator(object):
) )
xlation_builder = onmt.translate.TranslationBuilder( xlation_builder = onmt.translate.TranslationBuilder(
data, self.fields, self.n_best, self.replace_unk, tgt data, self.fields, self.n_best, self.replace_unk, tgt,
self.phrase_table
) )
# Statistics # Statistics
@ -377,6 +405,14 @@ class Translator(object):
n_best_preds = [" ".join(pred) n_best_preds = [" ".join(pred)
for pred in trans.pred_sents[:self.n_best]] for pred in trans.pred_sents[:self.n_best]]
if self.report_align:
align_pharaohs = [build_align_pharaoh(align) for align
in trans.word_aligns[:self.n_best]]
n_best_preds_align = [" ".join(align) for align
in align_pharaohs]
n_best_preds = [pred + " ||| " + align
for pred, align in zip(
n_best_preds, n_best_preds_align)]
all_predictions += [n_best_preds] all_predictions += [n_best_preds]
self.out_file.write('\n'.join(n_best_preds) + '\n') self.out_file.write('\n'.join(n_best_preds) + '\n')
self.out_file.flush() self.out_file.flush()
@ -397,27 +433,28 @@ class Translator(object):
srcs = trans.src_raw srcs = trans.src_raw
else: else:
srcs = [str(item) for item in range(len(attns[0]))] srcs = [str(item) for item in range(len(attns[0]))]
header_format = "{:>10.10} " + "{:>10.7} " * len(srcs) output = report_matrix(srcs, preds, attns)
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs)
output = header_format.format("", *srcs) + '\n' if self.logger:
covatn = [] self.logger.info(output)
covatn2d = [] else:
print(len(preds),len(attns)) os.write(1, output.encode('utf-8'))
for word, row in zip(preds, attns):
max_index = row.index(max(row)) if align_debug:
# if not covatn: if trans.gold_sent is not None:
# covatn = [0]*len(row) tgts = trans.gold_sent
# covatn[max_index] += 1 else:
covatn2d.append(row) tgts = trans.pred_sents[0]
row_format = row_format.replace( align = trans.word_aligns[0].tolist()
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1) if self.data_type == 'text':
row_format = row_format.replace( srcs = trans.src_raw
"{:*>10.7f} ", "{:>10.7f} ", max_index) else:
output += row_format.format(word, *row) + '\n' srcs = [str(item) for item in range(len(align[0]))]
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs) output = report_matrix(srcs, tgts, align)
print(output) if self.logger:
# print(covatn2d) self.logger.info(output)
# os.write(1, output.encode('utf-8')) else:
os.write(1, output.encode('utf-8'))
end_time = time.time() end_time = time.time()
@ -429,12 +466,6 @@ class Translator(object):
msg = self._report_score('GOLD', gold_score_total, msg = self._report_score('GOLD', gold_score_total,
gold_words_total) gold_words_total)
self._log(msg) self._log(msg)
if self.report_bleu:
msg = self._report_bleu(tgt)
self._log(msg)
if self.report_rouge:
msg = self._report_rouge(tgt)
self._log(msg)
if self.report_time: if self.report_time:
total_time = end_time - start_time total_time = end_time - start_time
@ -448,124 +479,126 @@ class Translator(object):
import json import json
json.dump(self.translator.beam_accum, json.dump(self.translator.beam_accum,
codecs.open(self.dump_beam, 'w', 'utf-8')) codecs.open(self.dump_beam, 'w', 'utf-8'))
if attn_debug: if attn_debug:
return all_scores, all_predictions, covatn2d return all_scores, all_predictions, attns
else: else:
return all_scores, all_predictions return all_scores, all_predictions
def _translate_random_sampling( def _align_pad_prediction(self, predictions, bos, pad):
self, """
batch, Padding predictions in batch and add BOS.
src_vocabs,
max_length,
min_length=0,
sampling_temp=1.0,
keep_topk=-1,
return_attention=False):
"""Alternative to beam search. Do random sampling at each step."""
assert self.beam_size == 1 Args:
predictions (List[List[Tensor]]): `(batch, n_best,)`, for each src
sequence contain n_best tgt predictions all of which ended with
eos id.
bos (int): bos index to be used.
pad (int): pad index to be used.
# TODO: support these blacklisted features. Return:
assert self.block_ngram_repeat == 0 batched_nbest_predict (torch.LongTensor): `(batch, n_best, tgt_l)`
"""
dtype, device = predictions[0][0].dtype, predictions[0][0].device
flatten_tgt = [best.tolist() for bests in predictions
for best in bests]
paded_tgt = torch.tensor(
list(zip_longest(*flatten_tgt, fillvalue=pad)),
dtype=dtype, device=device).T
bos_tensor = torch.full([paded_tgt.size(0), 1], bos,
dtype=dtype, device=device)
full_tgt = torch.cat((bos_tensor, paded_tgt), dim=-1)
batched_nbest_predict = full_tgt.view(
len(predictions), -1, full_tgt.size(-1)) # (batch, n_best, tgt_l)
return batched_nbest_predict
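A standalone re-run of the padding step above with hypothetical token ids (bos=2, pad=1, eos=3), showing how two predictions of different length end up in one `(batch, n_best, tgt_len)` tensor:

```python
import torch
from itertools import zip_longest

# Toy example (hypothetical ids), batch of 2 predictions with n_best=1.
predictions = [[torch.tensor([7, 8, 3])],      # "7 8 <eos>"
               [torch.tensor([5, 3])]]         # "5 <eos>"
bos, pad = 2, 1
flat = [best.tolist() for bests in predictions for best in bests]
padded = torch.tensor(list(zip_longest(*flat, fillvalue=pad))).T
full = torch.cat((torch.full([padded.size(0), 1], bos, dtype=torch.long), padded),
                 dim=-1)
print(full.view(len(predictions), -1, full.size(-1)))
# tensor([[[2, 7, 8, 3]],
#         [[2, 5, 3, 1]]])   -> shape (batch=2, n_best=1, tgt_len=4)
```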
batch_size = batch.batch_size def _align_forward(self, batch, predictions):
"""
For a batch of inputs and their predictions, return a list of
source-index alignment Tensors of size ``(batch, n_best,)``.
"""
# (0) add BOS and padding to tgt prediction
if hasattr(batch, 'tgt'):
batch_tgt_idxs = batch.tgt.transpose(1, 2).transpose(0, 2)
else:
batch_tgt_idxs = self._align_pad_prediction(
predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx)
tgt_mask = (batch_tgt_idxs.eq(self._tgt_pad_idx) |
batch_tgt_idxs.eq(self._tgt_eos_idx) |
batch_tgt_idxs.eq(self._tgt_bos_idx))
# Encoder forward. n_best = batch_tgt_idxs.size(1)
# (1) Encoder forward.
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch) src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
# (2) Repeat src objects `n_best` times.
# We use batch_size x n_best, get ``(src_len, batch * n_best, nfeat)``
src = tile(src, n_best, dim=1)
enc_states = tile(enc_states, n_best, dim=1)
if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, n_best, dim=1) for x in memory_bank)
else:
memory_bank = tile(memory_bank, n_best, dim=1)
src_lengths = tile(src_lengths, n_best) # ``(batch * n_best,)``
# (3) Init decoder with n_best src,
self.model.decoder.init_state(src, memory_bank, enc_states) self.model.decoder.init_state(src, memory_bank, enc_states)
# reshape tgt to ``(len, batch * n_best, nfeat)``
tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
dec_in = tgt[:-1] # exclude last target from inputs
_, attns = self.model.decoder(
dec_in, memory_bank, memory_lengths=src_lengths, with_align=True)
use_src_map = self.copy_attn alignment_attn = attns["align"] # ``(B, tgt_len-1, src_len)``
# masked_select
results = { align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
"predictions": None, prediction_mask = align_tgt_mask[:, 1:] # exclude bos to match pred
"scores": None, # get aligned src id for each prediction's valid tgt tokens
"attention": None, alignement = extract_alignment(
"batch": batch, alignment_attn, prediction_mask, src_lengths, n_best)
"gold_score": self._gold_score( return alignement
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)}
memory_lengths = src_lengths
src_map = batch.src_map if use_src_map else None
if isinstance(memory_bank, tuple):
mb_device = memory_bank[0].device
else:
mb_device = memory_bank.device
random_sampler = RandomSampling(
self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
batch_size, mb_device, min_length, self.block_ngram_repeat,
self._exclusion_idxs, return_attention, self.max_length,
sampling_temp, keep_topk, memory_lengths)
for step in range(max_length):
# Shape: (1, B, 1)
decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)
log_probs, attn = self._decode_and_generate(
decoder_input,
memory_bank,
batch,
src_vocabs,
memory_lengths=memory_lengths,
src_map=src_map,
step=step,
batch_offset=random_sampler.select_indices
)
random_sampler.advance(log_probs, attn)
any_batch_is_finished = random_sampler.is_finished.any()
if any_batch_is_finished:
random_sampler.update_finished()
if random_sampler.done:
break
if any_batch_is_finished:
select_indices = random_sampler.select_indices
# Reorder states.
if isinstance(memory_bank, tuple):
memory_bank = tuple(x.index_select(1, select_indices)
for x in memory_bank)
else:
memory_bank = memory_bank.index_select(1, select_indices)
memory_lengths = memory_lengths.index_select(0, select_indices)
if src_map is not None:
src_map = src_map.index_select(1, select_indices)
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
results["scores"] = random_sampler.scores
results["predictions"] = random_sampler.predictions
results["attention"] = random_sampler.attention
return results
def translate_batch(self, batch, src_vocabs, attn_debug): def translate_batch(self, batch, src_vocabs, attn_debug):
"""Translate a batch of sentences.""" """Translate a batch of sentences."""
with torch.no_grad(): with torch.no_grad():
if self.beam_size == 1: if self.beam_size == 1:
return self._translate_random_sampling( decode_strategy = GreedySearch(
batch, pad=self._tgt_pad_idx,
src_vocabs, bos=self._tgt_bos_idx,
self.max_length, eos=self._tgt_eos_idx,
min_length=self.min_length, batch_size=batch.batch_size,
min_length=self.min_length, max_length=self.max_length,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
return_attention=attn_debug or self.replace_unk,
sampling_temp=self.random_sampling_temp, sampling_temp=self.random_sampling_temp,
keep_topk=self.sample_from_topk, keep_topk=self.sample_from_topk)
return_attention=attn_debug or self.replace_unk)
else: else:
return self._translate_batch( # TODO: support these blacklisted features
batch, assert not self.dump_beam
src_vocabs, max_length = self.max_length
self.max_length, if self.dymax_len is not None:
min_length=self.min_length, if self.partial:
ratio=self.ratio, max_length = len(self.partial) + 3
if not self.partial and self.partialf:
max_length = 3
decode_strategy = BeamSearch(
self.beam_size,
batch_size=batch.batch_size,
pad=self._tgt_pad_idx,
bos=self._tgt_bos_idx,
eos=self._tgt_eos_idx,
n_best=self.n_best, n_best=self.n_best,
return_attention=attn_debug or self.replace_unk) global_scorer=self.global_scorer,
min_length=self.min_length, max_length=max_length,
return_attention=attn_debug or self.replace_unk,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
stepwise_penalty=self.stepwise_penalty,
ratio=self.ratio)
return self._translate_batch_with_strategy(batch, src_vocabs,
decode_strategy)
def _run_encoder(self, batch): def _run_encoder(self, batch):
src, src_lengths = batch.src if isinstance(batch.src, tuple) \ src, src_lengths = batch.src if isinstance(batch.src, tuple) \
@ -622,7 +655,8 @@ class Translator(object):
src_map) src_map)
# here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab] # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
if batch_offset is None: if batch_offset is None:
scores = scores.view(batch.batch_size, -1, scores.size(-1)) scores = scores.view(-1, batch.batch_size, scores.size(-1))
scores = scores.transpose(0, 1).contiguous()
else: else:
scores = scores.view(-1, self.beam_size, scores.size(-1)) scores = scores.view(-1, self.beam_size, scores.size(-1))
scores = collapse_copy_scores( scores = collapse_copy_scores(
@ -639,25 +673,25 @@ class Translator(object):
# or [ tgt_len, batch_size, vocab ] when full sentence # or [ tgt_len, batch_size, vocab ] when full sentence
return log_probs, attn return log_probs, attn
def _translate_batch( def _translate_batch_with_strategy(
self, self,
batch, batch,
src_vocabs, src_vocabs,
max_length, decode_strategy):
min_length=0, """Translate a batch of sentences step by step using cache.
ratio=0.,
n_best=1, Args:
return_attention=False): batch: a batch of sentences, yield by data iterator.
# TODO: support these blacklisted features. src_vocabs (list): list of torchtext.data.Vocab if can_copy.
assert not self.dump_beam decode_strategy (DecodeStrategy): A decode strategy to use for
if self.dymax_len is not None: generate translation step by step.
if self.partial:
max_length = len(self.partial) + 3 Returns:
if not self.partial and self.partialf: results (dict): The translation results.
max_length = 3 """
# (0) Prep the components of the search. # (0) Prep the components of the search.
use_src_map = self.copy_attn use_src_map = self.copy_attn
beam_size = self.beam_size parallel_paths = decode_strategy.parallel_paths # beam_size
batch_size = batch.batch_size batch_size = batch.batch_size
# (1) Run the encoder on the src. # (1) Run the encoder on the src.
@ -673,42 +707,16 @@ class Translator(object):
batch, memory_bank, src_lengths, src_vocabs, use_src_map, batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)} enc_states, batch_size, src)}
# (2) Repeat src objects `beam_size` times. # (2) prep decode_strategy. Possibly repeat src objects.
# We use batch_size x beam_size src_map = batch.src_map if use_src_map else None
src_map = (tile(batch.src_map, beam_size, dim=1) fn_map_state, memory_bank, memory_lengths, src_map = \
if use_src_map else None) decode_strategy.initialize(memory_bank, src_lengths, src_map)
self.model.decoder.map_state( if fn_map_state is not None:
lambda state, dim: tile(state, beam_size, dim=dim)) self.model.decoder.map_state(fn_map_state)
if isinstance(memory_bank, tuple): # (3) Begin decoding step by step:
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) for step in range(decode_strategy.max_length):
mb_device = memory_bank[0].device decoder_input = decode_strategy.current_predictions.view(1, -1, 1)
else:
memory_bank = tile(memory_bank, beam_size, dim=1)
mb_device = memory_bank.device
memory_lengths = tile(src_lengths, beam_size)
# (0) pt 2, prep the beam object
beam = BeamSearch(
beam_size,
n_best=n_best,
batch_size=batch_size,
global_scorer=self.global_scorer,
pad=self._tgt_pad_idx,
eos=self._tgt_eos_idx,
bos=self._tgt_bos_idx,
min_length=min_length,
ratio=ratio,
max_length=max_length,
mb_device=mb_device,
return_attention=return_attention,
stepwise_penalty=self.stepwise_penalty,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
memory_lengths=memory_lengths)
for step in range(max_length):
decoder_input = beam.current_predictions.view(1, -1, 1)
log_probs, attn = self._decode_and_generate( log_probs, attn = self._decode_and_generate(
decoder_input, decoder_input,
@ -718,18 +726,18 @@ class Translator(object):
memory_lengths=memory_lengths, memory_lengths=memory_lengths,
src_map=src_map, src_map=src_map,
step=step, step=step,
batch_offset=beam._batch_offset) batch_offset=decode_strategy.batch_offset)
beam.advance(log_probs, attn, self.partial, self.partialf) decode_strategy.advance(log_probs, attn, self.partial, self.partialf)
any_beam_is_finished = beam.is_finished.any() any_finished = decode_strategy.is_finished.any()
if any_beam_is_finished: if any_finished:
beam.update_finished() decode_strategy.update_finished()
if beam.done: if decode_strategy.done:
break break
select_indices = beam.current_origin select_indices = decode_strategy.select_indices
if any_beam_is_finished: if any_finished:
# Reorder states. # Reorder states.
if isinstance(memory_bank, tuple): if isinstance(memory_bank, tuple):
memory_bank = tuple(x.index_select(1, select_indices) memory_bank = tuple(x.index_select(1, select_indices)
@ -742,107 +750,18 @@ class Translator(object):
if src_map is not None: if src_map is not None:
src_map = src_map.index_select(1, select_indices) src_map = src_map.index_select(1, select_indices)
if parallel_paths > 1 or any_finished:
self.model.decoder.map_state( self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices)) lambda state, dim: state.index_select(dim, select_indices))
results["scores"] = beam.scores results["scores"] = decode_strategy.scores
results["predictions"] = beam.predictions results["predictions"] = decode_strategy.predictions
results["attention"] = beam.attention results["attention"] = decode_strategy.attention
return results if self.report_align:
results["alignment"] = self._align_forward(
# This is left in the code for now, but unsued batch, decode_strategy.predictions)
def _translate_batch_deprecated(self, batch, src_vocabs):
# (0) Prep each of the components of the search.
# And helper method for reducing verbosity.
use_src_map = self.copy_attn
beam_size = self.beam_size
batch_size = batch.batch_size
beam = [onmt.translate.Beam(
beam_size,
n_best=self.n_best,
cuda=self.cuda,
global_scorer=self.global_scorer,
pad=self._tgt_pad_idx,
eos=self._tgt_eos_idx,
bos=self._tgt_bos_idx,
min_length=self.min_length,
stepwise_penalty=self.stepwise_penalty,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs)
for __ in range(batch_size)]
# (1) Run the encoder on the src.
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
self.model.decoder.init_state(src, memory_bank, enc_states)
results = {
"predictions": [],
"scores": [],
"attention": [],
"batch": batch,
"gold_score": self._gold_score(
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)}
# (2) Repeat src objects `beam_size` times.
# We use now batch_size x beam_size (same as fast mode)
src_map = (tile(batch.src_map, beam_size, dim=1)
if use_src_map else None)
self.model.decoder.map_state(
lambda state, dim: tile(state, beam_size, dim=dim))
if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
else: else:
memory_bank = tile(memory_bank, beam_size, dim=1) results["alignment"] = [[] for _ in range(batch_size)]
memory_lengths = tile(src_lengths, beam_size)
# (3) run the decoder to generate sentences, using beam search.
for i in range(self.max_length):
if all((b.done for b in beam)):
break
# (a) Construct batch x beam_size nxt words.
# Get all the pending current beam words and arrange for forward.
inp = torch.stack([b.current_predictions for b in beam])
inp = inp.view(1, -1, 1)
# (b) Decode and forward
out, beam_attn = self._decode_and_generate(
inp, memory_bank, batch, src_vocabs,
memory_lengths=memory_lengths, src_map=src_map, step=i
)
out = out.view(batch_size, beam_size, -1)
beam_attn = beam_attn.view(batch_size, beam_size, -1)
# (c) Advance each beam.
select_indices_array = []
# Loop over the batch_size number of beam
for j, b in enumerate(beam):
if not b.done:
b.advance(out[j, :],
beam_attn.data[j, :, :memory_lengths[j]])
select_indices_array.append(
b.current_origin + j * beam_size)
select_indices = torch.cat(select_indices_array)
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
# (4) Extract sentences from beam.
for b in beam:
scores, ks = b.sort_finished(minimum=self.n_best)
hyps, attn = [], []
for times, k in ks[:self.n_best]:
hyp, att = b.get_hyp(times, k)
hyps.append(hyp)
attn.append(att)
results["predictions"].append(hyps)
results["scores"].append(scores)
results["attention"].append(attn)
return results return results
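For readability, here is a condensed sketch of the strategy-driven loop that `_translate_batch_with_strategy` now implements for both greedy and beam search. `model` and `decode_and_generate` are stand-ins for the Translator's internals; only the control flow mirrors the diff above.

```python
# Condensed sketch of the DecodeStrategy-driven decoding loop.
def run_decode_strategy(decode_strategy, model, decode_and_generate,
                        memory_bank, src_lengths, src_map=None):
    fn_map_state, memory_bank, memory_lengths, src_map = \
        decode_strategy.initialize(memory_bank, src_lengths, src_map)
    if fn_map_state is not None:
        model.decoder.map_state(fn_map_state)

    for step in range(decode_strategy.max_length):
        dec_in = decode_strategy.current_predictions.view(1, -1, 1)
        log_probs, attn = decode_and_generate(dec_in, memory_bank, step)
        decode_strategy.advance(log_probs, attn)

        any_finished = decode_strategy.is_finished.any()
        if any_finished:
            decode_strategy.update_finished()
            if decode_strategy.done:
                break

        select_indices = decode_strategy.select_indices
        if any_finished:
            # drop finished paths from the encoder-side state
            memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)
        if decode_strategy.parallel_paths > 1 or any_finished:
            # beams were re-ranked, so reorder the decoder state as well
            model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

    return decode_strategy.predictions, decode_strategy.scores
```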
def _score_target(self, batch, memory_bank, src_lengths, def _score_target(self, batch, memory_bank, src_lengths,
@ -855,7 +774,7 @@ class Translator(object):
memory_lengths=src_lengths, src_map=src_map) memory_lengths=src_lengths, src_map=src_map)
log_probs[:, :, self._tgt_pad_idx] = 0 log_probs[:, :, self._tgt_pad_idx] = 0
gold = tgt_in gold = tgt[1:]
gold_scores = log_probs.gather(2, gold) gold_scores = log_probs.gather(2, gold)
gold_scores = gold_scores.sum(dim=0).view(-1) gold_scores = gold_scores.sum(dim=0).view(-1)
@ -865,31 +784,9 @@ class Translator(object):
if words_total == 0: if words_total == 0:
msg = "%s No words predicted" % (name,) msg = "%s No words predicted" % (name,)
else: else:
avg_score = score_total / words_total
ppl = np.exp(-score_total.item() / words_total)
msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % ( msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
name, score_total / words_total, name, avg_score,
name, math.exp(-score_total / words_total))) name, ppl))
return msg
def _report_bleu(self, tgt_path):
import subprocess
base_dir = os.path.abspath(__file__ + "/../../..")
# Rollback pointer to the beginning.
self.out_file.seek(0)
print()
res = subprocess.check_output(
"perl %s/tools/multi-bleu.perl %s" % (base_dir, tgt_path),
stdin=self.out_file, shell=True
).decode("utf-8")
msg = ">> " + res.strip()
return msg
def _report_rouge(self, tgt_path):
import subprocess
path = os.path.split(os.path.realpath(__file__))[0]
msg = subprocess.check_output(
"python %s/tools/test_rouge.py -r %s -c STDIN" % (path, tgt_path),
shell=True, stdin=self.out_file
).decode("utf-8").strip()
return msg return msg
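The reporting change above computes perplexity as the exponential of the negative average token log-probability. A quick numeric check with hypothetical totals:

```python
import numpy as np

# Hypothetical totals over a translated corpus.
score_total, words_total = -34.5, 30
avg_score = score_total / words_total          # -1.15
ppl = np.exp(-score_total / words_total)       # ~3.16
print("AVG SCORE: %.4f, PPL: %.4f" % (avg_score, ppl))
```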

@ -3,31 +3,47 @@ from snownlp import SnowNLP
import pkuseg import pkuseg
def wrap_str_func(func):
"""
Wrapper to apply str function to the proper key of return_dict.
"""
def wrapper(some_dict):
some_dict["seg"] = [func(item) for item in some_dict["seg"]]
return some_dict
return wrapper
# Chinese segmentation # Chinese segmentation
def zh_segmentator(line): @wrap_str_func
def zh_segmentator(line, server_model):
return " ".join(pkuseg.pkuseg().cut(line)) return " ".join(pkuseg.pkuseg().cut(line))
# Chinese simplify -> Chinese traditional standard # Chinese simplify -> Chinese traditional standard
def zh_traditional_standard(line): @wrap_str_func
def zh_traditional_standard(line, server_model):
return HanLP.convertToTraditionalChinese(line) return HanLP.convertToTraditionalChinese(line)
# Chinese simplify -> Chinese traditional (HongKong) # Chinese simplify -> Chinese traditional (HongKong)
def zh_traditional_hk(line): @wrap_str_func
def zh_traditional_hk(line, server_model):
return HanLP.s2hk(line) return HanLP.s2hk(line)
# Chinese simplify -> Chinese traditional (Taiwan) # Chinese simplify -> Chinese traditional (Taiwan)
def zh_traditional_tw(line): @wrap_str_func
def zh_traditional_tw(line, server_model):
return HanLP.s2tw(line) return HanLP.s2tw(line)
# Chinese traditional -> Chinese simplify (v1) # Chinese traditional -> Chinese simplify (v1)
def zh_simplify(line): @wrap_str_func
def zh_simplify(line, server_model):
return HanLP.convertToSimplifiedChinese(line) return HanLP.convertToSimplifiedChinese(line)
# Chinese traditional -> Chinese simplify (v2) # Chinese traditional -> Chinese simplify (v2)
def zh_simplify_v2(line): @wrap_str_func
def zh_simplify_v2(line, server_model):
return SnowNLP(line).han return SnowNLP(line).han
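A simplified, self-contained sketch of the decorator pattern used above (not the library's `wrap_str_func`; the exact wrapped-function signature is assumed): it lifts a per-string function so that it maps over the `"seg"` list of the segment dict passed around by the server.

```python
# Simplified sketch of the decorator pattern, with assumed signatures.
def wrap_str_func_sketch(func):
    def wrapper(some_dict, server_model=None):
        some_dict["seg"] = [func(item, server_model) for item in some_dict["seg"]]
        return some_dict
    return wrapper

@wrap_str_func_sketch
def upper_case(line, server_model):
    return line.upper()

print(upper_case({"seg": ["hello world"], "n_seg": 1}))
# {'seg': ['HELLO WORLD'], 'n_seg': 1}
```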

@ -3,6 +3,7 @@ from __future__ import unicode_literals, print_function
import torch import torch
from onmt.inputters.text_dataset import TextMultiField from onmt.inputters.text_dataset import TextMultiField
from onmt.utils.alignment import build_align_pharaoh
class TranslationBuilder(object): class TranslationBuilder(object):
@ -62,14 +63,18 @@ class TranslationBuilder(object):
len(translation_batch["predictions"])) len(translation_batch["predictions"]))
batch_size = batch.batch_size batch_size = batch.batch_size
preds, pred_score, attn, gold_score, indices = list(zip( preds, pred_score, attn, align, gold_score, indices = list(zip(
*sorted(zip(translation_batch["predictions"], *sorted(zip(translation_batch["predictions"],
translation_batch["scores"], translation_batch["scores"],
translation_batch["attention"], translation_batch["attention"],
translation_batch["alignment"],
translation_batch["gold_score"], translation_batch["gold_score"],
batch.indices.data), batch.indices.data),
key=lambda x: x[-1]))) key=lambda x: x[-1])))
if not any(align): # when align is an empty nested list
align = [None] * batch_size
# Sorting # Sorting
inds, perm = torch.sort(batch.indices) inds, perm = torch.sort(batch.indices)
if self._has_text_src: if self._has_text_src:
@ -91,7 +96,8 @@ class TranslationBuilder(object):
pred_sents = [self._build_target_tokens( pred_sents = [self._build_target_tokens(
src[:, b] if src is not None else None, src[:, b] if src is not None else None,
src_vocab, src_raw, src_vocab, src_raw,
preds[b][n], attn[b][n]) preds[b][n],
align[b][n] if align[b] is not None else attn[b][n])
for n in range(self.n_best)] for n in range(self.n_best)]
gold_sent = None gold_sent = None
if tgt is not None: if tgt is not None:
@ -103,7 +109,7 @@ class TranslationBuilder(object):
translation = Translation( translation = Translation(
src[:, b] if src is not None else None, src[:, b] if src is not None else None,
src_raw, pred_sents, attn[b], pred_score[b], src_raw, pred_sents, attn[b], pred_score[b],
gold_sent, gold_score[b] gold_sent, gold_score[b], align[b]
) )
translations.append(translation) translations.append(translation)
@ -122,13 +128,15 @@ class Translation(object):
translation. translation.
gold_sent (List[str]): Words from gold translation. gold_sent (List[str]): Words from gold translation.
gold_score (List[float]): Log-prob of gold translation. gold_score (List[float]): Log-prob of gold translation.
word_aligns (List[FloatTensor]): Words Alignment distribution for
each translation.
""" """
__slots__ = ["src", "src_raw", "pred_sents", "attns", "pred_scores", __slots__ = ["src", "src_raw", "pred_sents", "attns", "pred_scores",
"gold_sent", "gold_score"] "gold_sent", "gold_score", "word_aligns"]
def __init__(self, src, src_raw, pred_sents, def __init__(self, src, src_raw, pred_sents,
attn, pred_scores, tgt_sent, gold_score): attn, pred_scores, tgt_sent, gold_score, word_aligns):
self.src = src self.src = src
self.src_raw = src_raw self.src_raw = src_raw
self.pred_sents = pred_sents self.pred_sents = pred_sents
@ -136,6 +144,7 @@ class Translation(object):
self.pred_scores = pred_scores self.pred_scores = pred_scores
self.gold_sent = tgt_sent self.gold_sent = tgt_sent
self.gold_score = gold_score self.gold_score = gold_score
self.word_aligns = word_aligns
def log(self, sent_number): def log(self, sent_number):
""" """
@ -150,6 +159,12 @@ class Translation(object):
msg.append('PRED {}: {}\n'.format(sent_number, pred_sent)) msg.append('PRED {}: {}\n'.format(sent_number, pred_sent))
msg.append("PRED SCORE: {:.4f}\n".format(best_score)) msg.append("PRED SCORE: {:.4f}\n".format(best_score))
if self.word_aligns is not None:
pred_align = self.word_aligns[0]
pred_align_pharaoh = build_align_pharaoh(pred_align)
pred_align_sent = ' '.join(pred_align_pharaoh)
msg.append("ALIGN: {}\n".format(pred_align_sent))
if self.gold_sent is not None: if self.gold_sent is not None:
tgt_sent = ' '.join(self.gold_sent) tgt_sent = ' '.join(self.gold_sent)
msg.append('GOLD {}: {}\n'.format(sent_number, tgt_sent)) msg.append('GOLD {}: {}\n'.format(sent_number, tgt_sent))
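`build_align_pharaoh` renders the alignment head as `src-tgt` index pairs (Pharaoh/fast_align style). A hedged sketch of how such pairs could be derived from an attention-like matrix; the library's own thresholding details may differ:

```python
import torch

# Assumed layout: one row per target token, one column per source token.
align = torch.tensor([[0.9, 0.1, 0.0],
                      [0.1, 0.2, 0.7],
                      [0.2, 0.7, 0.1]])     # 3 tgt tokens x 3 src tokens
pairs = ["{}-{}".format(src_idx.item(), tgt_idx)
         for tgt_idx, src_idx in enumerate(align.argmax(dim=-1))]
print(" ".join(pairs))   # "0-0 2-1 1-2"
```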

@ -13,8 +13,13 @@ import importlib
import torch import torch
import onmt.opts import onmt.opts
from itertools import islice
from copy import deepcopy
from onmt.utils.logging import init_logger from onmt.utils.logging import init_logger
from onmt.utils.misc import set_random_seed from onmt.utils.misc import set_random_seed
from onmt.utils.misc import check_model_config
from onmt.utils.alignment import to_word_align
from onmt.utils.parse import ArgumentParser from onmt.utils.parse import ArgumentParser
from onmt.translate.translator import build_translator from onmt.translate.translator import build_translator
@ -69,6 +74,53 @@ class ServerModelError(Exception):
pass pass
class CTranslate2Translator(object):
"""
This class wraps the ctranslate2.Translator object to
reproduce the onmt.translate.translator API.
"""
def __init__(self, model_path, device, device_index,
batch_size, beam_size, n_best, preload=False):
import ctranslate2
self.translator = ctranslate2.Translator(
model_path,
device=device,
device_index=device_index,
inter_threads=1,
intra_threads=1,
compute_type="default")
self.batch_size = batch_size
self.beam_size = beam_size
self.n_best = n_best
if preload:
# perform a first request to initialize everything
dummy_translation = self.translate(["a"])
print("Performed a dummy translation to initialize the model",
dummy_translation)
time.sleep(1)
self.translator.unload_model(to_cpu=True)
def translate(self, texts_to_translate, batch_size=8):
batch = [item.split(" ") for item in texts_to_translate]
preds = self.translator.translate_batch(
batch,
max_batch_size=self.batch_size,
beam_size=self.beam_size,
num_hypotheses=self.n_best
)
scores = [[item["score"] for item in ex] for ex in preds]
predictions = [[" ".join(item["tokens"]) for item in ex]
for ex in preds]
return scores, predictions
def to_cpu(self):
self.translator.unload_model(to_cpu=True)
def to_gpu(self):
self.translator.load_model()
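With the wrapper above, a server model can opt into the CTranslate2 backend through a `ct2_model` entry in its configuration. A hypothetical entry (paths and ids are placeholders; it is shown here as a Python dict, while the real file is the JSON conf read by the translation server):

```python
# Hypothetical server configuration entry enabling the CTranslate2 backend.
conf_entry = {
    "id": 100,
    "models": ["model_0.pt"],        # regular OpenNMT checkpoint (relative to model_root)
    "ct2_model": "ct2_model_0",      # directory holding the converted CTranslate2 model
    "timeout": 600,
    "on_timeout": "to_cpu",
    "load": True,
    "opt": {"gpu": 0, "beam_size": 5, "n_best": 1, "batch_size": 32},
}
```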
class TranslationServer(object): class TranslationServer(object):
def __init__(self): def __init__(self):
self.models = {} self.models = {}
@ -89,13 +141,15 @@ class TranslationServer(object):
else: else:
raise ValueError("""Incorrect config file: missing 'models' raise ValueError("""Incorrect config file: missing 'models'
parameter for model #%d""" % i) parameter for model #%d""" % i)
check_model_config(conf, self.models_root)
kwargs = {'timeout': conf.get('timeout', None), kwargs = {'timeout': conf.get('timeout', None),
'load': conf.get('load', None), 'load': conf.get('load', None),
'preprocess_opt': conf.get('preprocess', None), 'preprocess_opt': conf.get('preprocess', None),
'tokenizer_opt': conf.get('tokenizer', None), 'tokenizer_opt': conf.get('tokenizer', None),
'postprocess_opt': conf.get('postprocess', None), 'postprocess_opt': conf.get('postprocess', None),
'on_timeout': conf.get('on_timeout', None), 'on_timeout': conf.get('on_timeout', None),
'model_root': conf.get('model_root', self.models_root) 'model_root': conf.get('model_root', self.models_root),
'ct2_model': conf.get('ct2_model', None)
} }
kwargs = {k: v for (k, v) in kwargs.items() if v is not None} kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
model_id = conf.get("id", None) model_id = conf.get("id", None)
@ -202,11 +256,9 @@ class ServerModel(object):
def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None, def __init__(self, opt, model_id, preprocess_opt=None, tokenizer_opt=None,
postprocess_opt=None, load=False, timeout=-1, postprocess_opt=None, load=False, timeout=-1,
on_timeout="to_cpu", model_root="./"): on_timeout="to_cpu", model_root="./", ct2_model=None):
self.model_root = model_root self.model_root = model_root
self.opt = self.parse_opt(opt) self.opt = self.parse_opt(opt)
if self.opt.n_best > 1:
raise ValueError("Values of n_best > 1 are not supported")
self.model_id = model_id self.model_id = model_id
self.preprocess_opt = preprocess_opt self.preprocess_opt = preprocess_opt
@ -215,6 +267,9 @@ class ServerModel(object):
self.timeout = timeout self.timeout = timeout
self.on_timeout = on_timeout self.on_timeout = on_timeout
self.ct2_model = os.path.join(model_root, ct2_model) \
if ct2_model is not None else None
self.unload_timer = None self.unload_timer = None
self.user_opt = opt self.user_opt = opt
self.tokenizer = None self.tokenizer = None
@ -224,7 +279,8 @@ class ServerModel(object):
else: else:
log_file = None log_file = None
self.logger = init_logger(log_file=log_file, self.logger = init_logger(log_file=log_file,
log_file_level=self.opt.log_file_level) log_file_level=self.opt.log_file_level,
rotate=True)
self.loading_lock = threading.Event() self.loading_lock = threading.Event()
self.loading_lock.set() self.loading_lock.set()
@ -232,67 +288,6 @@ class ServerModel(object):
set_random_seed(self.opt.seed, self.opt.cuda) set_random_seed(self.opt.seed, self.opt.cuda)
if load:
self.load()
def parse_opt(self, opt):
"""Parse the option set passed by the user using `onmt.opts`
Args:
opt (dict): Options passed by the user
Returns:
opt (argparse.Namespace): full set of options for the Translator
"""
prec_argv = sys.argv
sys.argv = sys.argv[:1]
parser = ArgumentParser()
onmt.opts.translate_opts(parser)
models = opt['models']
if not isinstance(models, (list, tuple)):
models = [models]
opt['models'] = [os.path.join(self.model_root, model)
for model in models]
opt['src'] = "dummy_src"
for (k, v) in opt.items():
if k == 'models':
sys.argv += ['-model']
sys.argv += [str(model) for model in v]
elif type(v) == bool:
sys.argv += ['-%s' % k]
else:
sys.argv += ['-%s' % k, str(v)]
opt = parser.parse_args()
ArgumentParser.validate_translate_opts(opt)
opt.cuda = opt.gpu > -1
sys.argv = prec_argv
return opt
@property
def loaded(self):
return hasattr(self, 'translator')
def load(self):
self.loading_lock.clear()
timer = Timer()
self.logger.info("Loading model %d" % self.model_id)
timer.start()
try:
self.translator = build_translator(self.opt,
report_score=False,
out_file=codecs.open(
os.devnull, "w", "utf-8"))
except RuntimeError as e:
raise ServerModelError("Runtime Error: %s" % str(e))
timer.tick("model_loading")
if self.preprocess_opt is not None: if self.preprocess_opt is not None:
self.logger.info("Loading preprocessor") self.logger.info("Loading preprocessor")
self.preprocessor = [] self.preprocessor = []
@ -347,6 +342,77 @@ class ServerModel(object):
function = get_function_by_path(function_path) function = get_function_by_path(function_path)
self.postprocessor.append(function) self.postprocessor.append(function)
if load:
self.load(preload=True)
self.stop_unload_timer()
def parse_opt(self, opt):
"""Parse the option set passed by the user using `onmt.opts`
Args:
opt (dict): Options passed by the user
Returns:
opt (argparse.Namespace): full set of options for the Translator
"""
prec_argv = sys.argv
sys.argv = sys.argv[:1]
parser = ArgumentParser()
onmt.opts.translate_opts(parser)
models = opt['models']
if not isinstance(models, (list, tuple)):
models = [models]
opt['models'] = [os.path.join(self.model_root, model)
for model in models]
opt['src'] = "dummy_src"
for (k, v) in opt.items():
if k == 'models':
sys.argv += ['-model']
sys.argv += [str(model) for model in v]
elif type(v) == bool:
sys.argv += ['-%s' % k]
else:
sys.argv += ['-%s' % k, str(v)]
opt = parser.parse_args()
ArgumentParser.validate_translate_opts(opt)
opt.cuda = opt.gpu > -1
sys.argv = prec_argv
return opt
@property
def loaded(self):
return hasattr(self, 'translator')
def load(self, preload=False):
self.loading_lock.clear()
timer = Timer()
self.logger.info("Loading model %d" % self.model_id)
timer.start()
try:
if self.ct2_model is not None:
self.translator = CTranslate2Translator(
self.ct2_model,
device="cuda" if self.opt.cuda else "cpu",
device_index=self.opt.gpu if self.opt.cuda else 0,
batch_size=self.opt.batch_size,
beam_size=self.opt.beam_size,
n_best=self.opt.n_best,
preload=preload)
else:
self.translator = build_translator(
self.opt, report_score=False,
out_file=codecs.open(os.devnull, "w", "utf-8"))
except RuntimeError as e:
raise ServerModelError("Runtime Error: %s" % str(e))
timer.tick("model_loading")
self.load_time = timer.tick() self.load_time = timer.tick()
self.reset_unload_timer() self.reset_unload_timer()
self.loading_lock.set() self.loading_lock.set()
@ -390,13 +456,9 @@ class ServerModel(object):
head_spaces = [] head_spaces = []
tail_spaces = [] tail_spaces = []
sslength = [] sslength = []
all_preprocessed = []
for i, inp in enumerate(inputs): for i, inp in enumerate(inputs):
src = inp['src'] src = inp['src']
if src.strip() == "":
head_spaces.append(src)
texts.append("")
tail_spaces.append("")
else:
whitespaces_before, whitespaces_after = "", "" whitespaces_before, whitespaces_after = "", ""
match_before = re.search(r'^\s+', src) match_before = re.search(r'^\s+', src)
match_after = re.search(r'\s+$', src) match_after = re.search(r'\s+$', src)
@ -405,8 +467,11 @@ class ServerModel(object):
if match_after is not None: if match_after is not None:
whitespaces_after = match_after.group(0) whitespaces_after = match_after.group(0)
head_spaces.append(whitespaces_before) head_spaces.append(whitespaces_before)
preprocessed_src = self.maybe_preprocess(src.strip()) # every segment becomes a dict for flexibility purposes
tok = self.maybe_tokenize(preprocessed_src) seg_dict = self.maybe_preprocess(src.strip())
all_preprocessed.append(seg_dict)
for seg in seg_dict["seg"]:
tok = self.maybe_tokenize(seg)
texts.append(tok) texts.append(tok)
sslength.append(len(tok.split())) sslength.append(len(tok.split()))
tail_spaces.append(whitespaces_after) tail_spaces.append(whitespaces_after)
@ -441,28 +506,69 @@ class ServerModel(object):
        self.reset_unload_timer()

        # NOTE: translator returns lists of `n_best` list
        # we can ignore that (i.e. flatten lists) only because
        # we restrict `n_best=1`
        def flatten_list(_list): return sum(_list, [])

        tiled_texts = [t for t in texts_to_translate
                       for _ in range(self.opt.n_best)]
        results = flatten_list(predictions)

        scores = [score_tensor.item()
                  for score_tensor in flatten_list(scores)]

        def maybe_item(x): return x.item() if type(x) is torch.Tensor else x
        scores = [maybe_item(score_tensor)
                  for score_tensor in flatten_list(scores)]

        results = [self.maybe_detokenize(item)
                   for item in results]

        results = [self.maybe_detokenize_with_align(result, src)
                   for result, src in zip(results, tiled_texts)]

        aligns = [align for _, align in results]

        results = [self.maybe_postprocess(item)
                   for item in results]

        # build back results with empty texts
        for i in empty_indices:
            results.insert(i, "")
            scores.insert(i, 0)
            j = i * self.opt.n_best
            results = (results[:j] +
                       [("", None)] * self.opt.n_best + results[j:])
            aligns = aligns[:j] + [None] * self.opt.n_best + aligns[j:]
            scores = scores[:j] + [0] * self.opt.n_best + scores[j:]

        rebuilt_segs, scores, aligns = self.rebuild_seg_packages(
            all_preprocessed, results, scores, aligns, self.opt.n_best)

        results = [self.maybe_postprocess(seg) for seg in rebuilt_segs]

        head_spaces = [h for h in head_spaces for i in range(self.opt.n_best)]
        tail_spaces = [h for h in tail_spaces for i in range(self.opt.n_best)]

        results = ["".join(items)
                   for items in zip(head_spaces, results, tail_spaces)]

        self.logger.info("Translation Results: %d", len(results))

        return results, scores, self.opt.n_best, timer.times
return results, scores, self.opt.n_best, timer.times, aligns
def rebuild_seg_packages(self, all_preprocessed, results,
scores, aligns, n_best):
"""
Rebuild proper segment packages based on initial n_seg.
"""
offset = 0
rebuilt_segs = []
avg_scores = []
merged_aligns = []
for i, seg_dict in enumerate(all_preprocessed):
sub_results = results[n_best * offset:
(offset + seg_dict["n_seg"]) * n_best]
sub_scores = scores[n_best * offset:
(offset + seg_dict["n_seg"]) * n_best]
sub_aligns = aligns[n_best * offset:
(offset + seg_dict["n_seg"]) * n_best]
for j in range(n_best):
_seg_dict = deepcopy(seg_dict)
_sub_segs = list(list(zip(*sub_results))[0])
_seg_dict["seg"] = list(islice(_sub_segs, j, None, n_best))
rebuilt_segs.append(_seg_dict)
sub_sub_scores = list(islice(sub_scores, j, None, n_best))
avg_scores.append(sum(sub_sub_scores)/_seg_dict["n_seg"])
sub_sub_aligns = list(islice(sub_aligns, j, None, n_best))
merged_aligns.append(sub_sub_aligns)
offset += _seg_dict["n_seg"]
return rebuilt_segs, avg_scores, merged_aligns
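For readers following the new segment-dict flow, here is a small self-contained sketch (an editor illustration, not part of the commit) of the regrouping that `rebuild_seg_packages` performs; the helper name `regroup` and the sample data are invented.

```python
from itertools import islice

def regroup(all_preprocessed, flat_results, flat_scores, n_best):
    """Illustrative stand-in: regroup flat per-segment n_best results
    into one package per original input, averaging segment scores."""
    offset, packages, avg_scores = 0, [], []
    for seg_dict in all_preprocessed:
        n_seg = seg_dict["n_seg"]
        sub_results = flat_results[n_best * offset:(offset + n_seg) * n_best]
        sub_scores = flat_scores[n_best * offset:(offset + n_seg) * n_best]
        for j in range(n_best):
            # take the j-th best hypothesis of every segment of this input
            segs = list(islice(sub_results, j, None, n_best))
            best_scores = list(islice(sub_scores, j, None, n_best))
            packages.append({"seg": segs, "n_seg": n_seg})
            avg_scores.append(sum(best_scores) / n_seg)
        offset += n_seg
    return packages, avg_scores

# two inputs: the first was preprocessed into 2 segments, the second into 1
inputs = [{"seg": ["a b", "c d"], "n_seg": 2}, {"seg": ["e f"], "n_seg": 1}]
print(regroup(inputs, ["A B", "C D", "E F"], [-0.1, -0.3, -0.2], n_best=1))
```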
def do_timeout(self): def do_timeout(self):
"""Timeout function that frees GPU memory. """Timeout function that frees GPU memory.
@ -485,6 +591,7 @@ class ServerModel(object):
del self.translator del self.translator
if self.opt.cuda: if self.opt.cuda:
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.stop_unload_timer()
self.unload_timer = None self.unload_timer = None
def stop_unload_timer(self): def stop_unload_timer(self):
@ -515,12 +622,18 @@ class ServerModel(object):
@critical @critical
def to_cpu(self): def to_cpu(self):
"""Move the model to CPU and clear CUDA cache.""" """Move the model to CPU and clear CUDA cache."""
if type(self.translator) == CTranslate2Translator:
self.translator.to_cpu()
else:
self.translator.model.cpu() self.translator.model.cpu()
if self.opt.cuda: if self.opt.cuda:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def to_gpu(self): def to_gpu(self):
"""Move the model to GPU.""" """Move the model to GPU."""
if type(self.translator) == CTranslate2Translator:
self.translator.to_gpu()
else:
torch.cuda.set_device(self.opt.gpu) torch.cuda.set_device(self.opt.gpu)
self.translator.model.cuda() self.translator.model.cuda()
@ -528,7 +641,11 @@ class ServerModel(object):
"""Preprocess the sequence (or not) """Preprocess the sequence (or not)
""" """
if type(sequence) is str:
sequence = {
"seg": [sequence],
"n_seg": 1
}
if self.preprocess_opt is not None: if self.preprocess_opt is not None:
return self.preprocess(sequence) return self.preprocess(sequence)
return sequence return sequence
@ -545,7 +662,7 @@ class ServerModel(object):
if self.preprocessor is None: if self.preprocessor is None:
raise ValueError("No preprocessor loaded") raise ValueError("No preprocessor loaded")
for function in self.preprocessor: for function in self.preprocessor:
sequence = function(sequence) sequence = function(sequence, self)
return sequence return sequence
def maybe_tokenize(self, sequence): def maybe_tokenize(self, sequence):
@ -579,6 +696,42 @@ class ServerModel(object):
tok = " ".join(tok) tok = " ".join(tok)
return tok return tok
@property
def tokenizer_marker(self):
marker = None
tokenizer_type = self.tokenizer_opt.get('type', None)
if tokenizer_type == "pyonmttok":
params = self.tokenizer_opt.get('params', None)
if params is not None:
if params.get("joiner_annotate", None) is not None:
marker = 'joiner'
elif params.get("spacer_annotate", None) is not None:
marker = 'spacer'
elif tokenizer_type == "sentencepiece":
marker = 'spacer'
return marker
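As an aside, the marker lookup above only inspects the tokenizer options; a stand-alone equivalent behaves like this (illustrative only, config values invented):

```python
def marker_for(tokenizer_opt):
    """Illustrative copy of the tokenizer_marker lookup."""
    tokenizer_type = tokenizer_opt.get("type")
    params = tokenizer_opt.get("params") or {}
    if tokenizer_type == "pyonmttok":
        if params.get("joiner_annotate") is not None:
            return "joiner"
        if params.get("spacer_annotate") is not None:
            return "spacer"
    elif tokenizer_type == "sentencepiece":
        return "spacer"
    return None

print(marker_for({"type": "pyonmttok", "params": {"joiner_annotate": True}}))  # joiner
print(marker_for({"type": "sentencepiece"}))                                   # spacer
```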
def maybe_detokenize_with_align(self, sequence, src):
"""De-tokenize (or not) the sequence (with alignment).
Args:
            sequence (str): The sequence to detokenize, possibly with
                an alignment separated by ` ||| `.

        Returns:
            sequence (str): The detokenized sequence.
            align (str): The alignment corresponding to the detokenized
                src/tgt, or None if there is no alignment in the output.
"""
align = None
if self.opt.report_align:
# output contain alignment
sequence, align = sequence.split(' ||| ')
if align != '':
align = self.maybe_convert_align(src, sequence, align)
sequence = self.maybe_detokenize(sequence)
return (sequence, align)
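A quick illustration of the ` ||| ` convention handled here (the strings are made up): with -report_align enabled, each hypothesis carries its Pharaoh alignment, which the server splits off before detokenizing.

```python
# hypothetical raw hypothesis produced with -report_align
raw = "▁Hallo ▁Welt ||| 0-0 1-1"

sequence, align = raw.split(' ||| ')
print(sequence)  # "▁Hallo ▁Welt"  -> fed to maybe_detokenize()
print(align)     # "0-0 1-1"       -> fed to maybe_convert_align()
```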
def maybe_detokenize(self, sequence): def maybe_detokenize(self, sequence):
"""De-tokenize the sequence (or not) """De-tokenize the sequence (or not)
@ -605,14 +758,29 @@ class ServerModel(object):
return detok return detok
def maybe_convert_align(self, src, tgt, align):
"""Convert alignment to match detokenized src/tgt (or not).
Args:
src (str): The tokenized source sequence.
tgt (str): The tokenized target sequence.
            align (str): The alignment corresponding to the src/tgt pair.

        Returns:
            align (str): The alignment corresponding to the detokenized src/tgt.
"""
if self.tokenizer_marker is not None and ''.join(tgt.split()) != '':
return to_word_align(src, tgt, align, mode=self.tokenizer_marker)
return align
def maybe_postprocess(self, sequence): def maybe_postprocess(self, sequence):
"""Postprocess the sequence (or not) """Postprocess the sequence (or not)
""" """
if self.postprocess_opt is not None: if self.postprocess_opt is not None:
return self.postprocess(sequence) return self.postprocess(sequence)
return sequence else:
return sequence["seg"][0]
def postprocess(self, sequence): def postprocess(self, sequence):
"""Preprocess a single sequence. """Preprocess a single sequence.
@ -626,7 +794,7 @@ class ServerModel(object):
if self.postprocessor is None: if self.postprocessor is None:
raise ValueError("No postprocessor loaded") raise ValueError("No postprocessor loaded")
for function in self.postprocessor: for function in self.postprocessor:
sequence = function(sequence) sequence = function(sequence, self)
return sequence return sequence

Просмотреть файл

@ -3,19 +3,19 @@
from __future__ import print_function from __future__ import print_function
import codecs import codecs
import os import os
import math
import time import time
from itertools import count import numpy as np
from itertools import count, zip_longest
import torch import torch
import onmt.model_builder import onmt.model_builder
import onmt.translate.beam
import onmt.inputters as inputters import onmt.inputters as inputters
import onmt.decoders.ensemble import onmt.decoders.ensemble
from onmt.translate.beam_search import BeamSearch from onmt.translate.beam_search import BeamSearch
from onmt.translate.random_sampling import RandomSampling from onmt.translate.greedy_search import GreedySearch
from onmt.utils.misc import tile, set_random_seed from onmt.utils.misc import tile, set_random_seed, report_matrix
from onmt.utils.alignment import extract_alignment, build_align_pharaoh
from onmt.modules.copy_generator import collapse_copy_scores from onmt.modules.copy_generator import collapse_copy_scores
@ -36,12 +36,32 @@ def build_translator(opt, report_score=True, logger=None, out_file=None):
model_opt, model_opt,
global_scorer=scorer, global_scorer=scorer,
out_file=out_file, out_file=out_file,
report_align=opt.report_align,
report_score=report_score, report_score=report_score,
logger=logger logger=logger
) )
return translator return translator
def max_tok_len(new, count, sofar):
"""
In token batching scheme, the number of sequences is limited
such that the total number of src/tgt tokens (including padding)
in a batch <= batch_size
"""
# Maintains the longest src and tgt length in the current batch
global max_src_in_batch # this is a hack
# Reset current longest length at a new batch (count=1)
if count == 1:
max_src_in_batch = 0
# max_tgt_in_batch = 0
# Src: [<bos> w1 ... wN <eos>]
max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2)
# Tgt: [w1 ... wM <eos>]
src_elements = count * max_src_in_batch
return src_elements
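To see what token batching does in practice, here is a rough stand-alone approximation (editor sketch, lengths invented) of how a batcher driven by max_tok_len closes a batch once count * longest-source-so-far would exceed the token budget:

```python
def simulate_token_batching(src_lengths, batch_size):
    """Approximate grouping: keep count * (longest src + 2) <= batch_size."""
    batches, current, longest = [], [], 0
    for length in src_lengths:
        longest = max(longest, length + 2)            # +2 for <bos>/<eos>
        if current and (len(current) + 1) * longest > batch_size:
            batches.append(current)                   # close the batch
            current, longest = [], length + 2
        current.append(length)
    if current:
        batches.append(current)
    return batches

print(simulate_token_batching([5, 7, 30, 4, 4, 4], batch_size=64))
# [[5, 7], [30, 4], [4, 4]]
```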
class Translator(object): class Translator(object):
"""Translate a batch of sentences with a saved model. """Translate a batch of sentences with a saved model.
@ -59,9 +79,9 @@ class Translator(object):
:class:`onmt.translate.decode_strategy.DecodeStrategy`. :class:`onmt.translate.decode_strategy.DecodeStrategy`.
beam_size (int): Number of beams. beam_size (int): Number of beams.
random_sampling_topk (int): See random_sampling_topk (int): See
:class:`onmt.translate.random_sampling.RandomSampling`. :class:`onmt.translate.greedy_search.GreedySearch`.
random_sampling_temp (int): See random_sampling_temp (int): See
:class:`onmt.translate.random_sampling.RandomSampling`. :class:`onmt.translate.greedy_search.GreedySearch`.
stepwise_penalty (bool): Whether coverage penalty is applied every step stepwise_penalty (bool): Whether coverage penalty is applied every step
or not. or not.
dump_beam (bool): Debugging option. dump_beam (bool): Debugging option.
@ -72,8 +92,6 @@ class Translator(object):
replace_unk (bool): Replace unknown token. replace_unk (bool): Replace unknown token.
data_type (str): Source data type. data_type (str): Source data type.
verbose (bool): Print/log every translation. verbose (bool): Print/log every translation.
report_bleu (bool): Print/log Bleu metric.
report_rouge (bool): Print/log Rouge metric.
report_time (bool): Print/log total time/frequency. report_time (bool): Print/log total time/frequency.
copy_attn (bool): Use copy attention. copy_attn (bool): Use copy attention.
global_scorer (onmt.translate.GNMTGlobalScorer): Translation global_scorer (onmt.translate.GNMTGlobalScorer): Translation
@ -105,12 +123,11 @@ class Translator(object):
phrase_table="", phrase_table="",
data_type="text", data_type="text",
verbose=False, verbose=False,
report_bleu=False,
report_rouge=False,
report_time=False, report_time=False,
copy_attn=False, copy_attn=False,
global_scorer=None, global_scorer=None,
out_file=None, out_file=None,
report_align=False,
report_score=True, report_score=True,
logger=None, logger=None,
seed=-1): seed=-1):
@ -153,8 +170,6 @@ class Translator(object):
self.phrase_table = phrase_table self.phrase_table = phrase_table
self.data_type = data_type self.data_type = data_type
self.verbose = verbose self.verbose = verbose
self.report_bleu = report_bleu
self.report_rouge = report_rouge
self.report_time = report_time self.report_time = report_time
self.copy_attn = copy_attn self.copy_attn = copy_attn
@ -165,6 +180,7 @@ class Translator(object):
raise ValueError( raise ValueError(
"Coverage penalty requires an attentional decoder.") "Coverage penalty requires an attentional decoder.")
self.out_file = out_file self.out_file = out_file
self.report_align = report_align
self.report_score = report_score self.report_score = report_score
self.logger = logger self.logger = logger
@ -192,6 +208,7 @@ class Translator(object):
model_opt, model_opt,
global_scorer=None, global_scorer=None,
out_file=None, out_file=None,
report_align=False,
report_score=True, report_score=True,
logger=None): logger=None):
"""Alternate constructor. """Alternate constructor.
@ -207,6 +224,7 @@ class Translator(object):
:func:`__init__()`.. :func:`__init__()`..
out_file (TextIO or codecs.StreamReaderWriter): See out_file (TextIO or codecs.StreamReaderWriter): See
:func:`__init__()`. :func:`__init__()`.
report_align (bool) : See :func:`__init__()`.
report_score (bool) : See :func:`__init__()`. report_score (bool) : See :func:`__init__()`.
logger (logging.Logger or NoneType): See :func:`__init__()`. logger (logging.Logger or NoneType): See :func:`__init__()`.
""" """
@ -234,12 +252,11 @@ class Translator(object):
phrase_table=opt.phrase_table, phrase_table=opt.phrase_table,
data_type=opt.data_type, data_type=opt.data_type,
verbose=opt.verbose, verbose=opt.verbose,
report_bleu=opt.report_bleu,
report_rouge=opt.report_rouge,
report_time=opt.report_time, report_time=opt.report_time,
copy_attn=model_opt.copy_attn, copy_attn=model_opt.copy_attn,
global_scorer=global_scorer, global_scorer=global_scorer,
out_file=out_file, out_file=out_file,
report_align=report_align,
report_score=report_score, report_score=report_score,
logger=logger, logger=logger,
seed=opt.seed) seed=opt.seed)
@ -267,7 +284,9 @@ class Translator(object):
tgt=None, tgt=None,
src_dir=None, src_dir=None,
batch_size=None, batch_size=None,
batch_type="sents",
attn_debug=False, attn_debug=False,
align_debug=False,
phrase_table=""): phrase_table=""):
"""Translate content of ``src`` and get gold scores from ``tgt``. """Translate content of ``src`` and get gold scores from ``tgt``.
@ -278,6 +297,7 @@ class Translator(object):
for certain types of data). for certain types of data).
batch_size (int): size of examples per mini-batch batch_size (int): size of examples per mini-batch
attn_debug (bool): enables the attention logging attn_debug (bool): enables the attention logging
align_debug (bool): enables the word alignment logging
Returns: Returns:
(`list`, `list`) (`list`, `list`)
@ -290,12 +310,16 @@ class Translator(object):
if batch_size is None: if batch_size is None:
raise ValueError("batch_size must be set") raise ValueError("batch_size must be set")
src_data = {"reader": self.src_reader, "data": src, "dir": src_dir}
tgt_data = {"reader": self.tgt_reader, "data": tgt, "dir": None}
_readers, _data, _dir = inputters.Dataset.config(
[('src', src_data), ('tgt', tgt_data)])
# corpus_id field is useless here
if self.fields.get("corpus_id", None) is not None:
self.fields.pop('corpus_id')
data = inputters.Dataset( data = inputters.Dataset(
self.fields, self.fields, readers=_readers, data=_data, dirs=_dir,
readers=([self.src_reader, self.tgt_reader]
if tgt else [self.src_reader]),
data=[("src", src), ("tgt", tgt)] if tgt else [("src", src)],
dirs=[src_dir, None] if tgt else [src_dir],
sort_key=inputters.str2sortkey[self.data_type], sort_key=inputters.str2sortkey[self.data_type],
filter_pred=self._filter_pred filter_pred=self._filter_pred
) )
@ -304,6 +328,7 @@ class Translator(object):
dataset=data, dataset=data,
device=self._dev, device=self._dev,
batch_size=batch_size, batch_size=batch_size,
batch_size_fn=max_tok_len if batch_type == "tokens" else None,
train=False, train=False,
sort=False, sort=False,
sort_within_batch=True, sort_within_batch=True,
@ -341,6 +366,14 @@ class Translator(object):
n_best_preds = [" ".join(pred) n_best_preds = [" ".join(pred)
for pred in trans.pred_sents[:self.n_best]] for pred in trans.pred_sents[:self.n_best]]
if self.report_align:
align_pharaohs = [build_align_pharaoh(align) for align
in trans.word_aligns[:self.n_best]]
n_best_preds_align = [" ".join(align) for align
in align_pharaohs]
n_best_preds = [pred + " ||| " + align
for pred, align in zip(
n_best_preds, n_best_preds_align)]
all_predictions += [n_best_preds] all_predictions += [n_best_preds]
self.out_file.write('\n'.join(n_best_preds) + '\n') self.out_file.write('\n'.join(n_best_preds) + '\n')
self.out_file.flush() self.out_file.flush()
@ -361,17 +394,23 @@ class Translator(object):
srcs = trans.src_raw srcs = trans.src_raw
else: else:
srcs = [str(item) for item in range(len(attns[0]))] srcs = [str(item) for item in range(len(attns[0]))]
header_format = "{:>10.10} " + "{:>10.7} " * len(srcs) output = report_matrix(srcs, preds, attns)
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs) if self.logger:
output = header_format.format("", *srcs) + '\n' self.logger.info(output)
for word, row in zip(preds, attns): else:
max_index = row.index(max(row)) os.write(1, output.encode('utf-8'))
row_format = row_format.replace(
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1) if align_debug:
row_format = row_format.replace( if trans.gold_sent is not None:
"{:*>10.7f} ", "{:>10.7f} ", max_index) tgts = trans.gold_sent
output += row_format.format(word, *row) + '\n' else:
row_format = "{:>10.10} " + "{:>10.7f} " * len(srcs) tgts = trans.pred_sents[0]
align = trans.word_aligns[0].tolist()
if self.data_type == 'text':
srcs = trans.src_raw
else:
srcs = [str(item) for item in range(len(align[0]))]
output = report_matrix(srcs, tgts, align)
if self.logger: if self.logger:
self.logger.info(output) self.logger.info(output)
else: else:
@ -387,12 +426,6 @@ class Translator(object):
msg = self._report_score('GOLD', gold_score_total, msg = self._report_score('GOLD', gold_score_total,
gold_words_total) gold_words_total)
self._log(msg) self._log(msg)
if self.report_bleu:
msg = self._report_bleu(tgt)
self._log(msg)
if self.report_rouge:
msg = self._report_rouge(tgt)
self._log(msg)
if self.report_time: if self.report_time:
total_time = end_time - start_time total_time = end_time - start_time
@ -408,119 +441,113 @@ class Translator(object):
codecs.open(self.dump_beam, 'w', 'utf-8')) codecs.open(self.dump_beam, 'w', 'utf-8'))
return all_scores, all_predictions return all_scores, all_predictions
def _translate_random_sampling( def _align_pad_prediction(self, predictions, bos, pad):
self, """
batch, Padding predictions in batch and add BOS.
src_vocabs,
max_length,
min_length=0,
sampling_temp=1.0,
keep_topk=-1,
return_attention=False):
"""Alternative to beam search. Do random sampling at each step."""
assert self.beam_size == 1 Args:
predictions (List[List[Tensor]]): `(batch, n_best,)`, for each src
sequence contain n_best tgt predictions all of which ended with
eos id.
bos (int): bos index to be used.
pad (int): pad index to be used.
# TODO: support these blacklisted features. Return:
assert self.block_ngram_repeat == 0 batched_nbest_predict (torch.LongTensor): `(batch, n_best, tgt_l)`
"""
dtype, device = predictions[0][0].dtype, predictions[0][0].device
flatten_tgt = [best.tolist() for bests in predictions
for best in bests]
paded_tgt = torch.tensor(
list(zip_longest(*flatten_tgt, fillvalue=pad)),
dtype=dtype, device=device).T
bos_tensor = torch.full([paded_tgt.size(0), 1], bos,
dtype=dtype, device=device)
full_tgt = torch.cat((bos_tensor, paded_tgt), dim=-1)
batched_nbest_predict = full_tgt.view(
len(predictions), -1, full_tgt.size(-1)) # (batch, n_best, tgt_l)
return batched_nbest_predict
batch_size = batch.batch_size def _align_forward(self, batch, predictions):
"""
For a batch of input and its prediction, return a list of batch predict
alignment src indice Tensor in size ``(batch, n_best,)``.
"""
# (0) add BOS and padding to tgt prediction
if hasattr(batch, 'tgt'):
batch_tgt_idxs = batch.tgt.transpose(1, 2).transpose(0, 2)
else:
batch_tgt_idxs = self._align_pad_prediction(
predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx)
tgt_mask = (batch_tgt_idxs.eq(self._tgt_pad_idx) |
batch_tgt_idxs.eq(self._tgt_eos_idx) |
batch_tgt_idxs.eq(self._tgt_bos_idx))
# Encoder forward. n_best = batch_tgt_idxs.size(1)
# (1) Encoder forward.
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch) src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
# (2) Repeat src objects `n_best` times.
# We use batch_size x n_best, get ``(src_len, batch * n_best, nfeat)``
src = tile(src, n_best, dim=1)
enc_states = tile(enc_states, n_best, dim=1)
if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, n_best, dim=1) for x in memory_bank)
else:
memory_bank = tile(memory_bank, n_best, dim=1)
src_lengths = tile(src_lengths, n_best) # ``(batch * n_best,)``
# (3) Init decoder with n_best src,
self.model.decoder.init_state(src, memory_bank, enc_states) self.model.decoder.init_state(src, memory_bank, enc_states)
# reshape tgt to ``(len, batch * n_best, nfeat)``
tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
dec_in = tgt[:-1] # exclude last target from inputs
_, attns = self.model.decoder(
dec_in, memory_bank, memory_lengths=src_lengths, with_align=True)
use_src_map = self.copy_attn alignment_attn = attns["align"] # ``(B, tgt_len-1, src_len)``
# masked_select
results = { align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
"predictions": None, prediction_mask = align_tgt_mask[:, 1:] # exclude bos to match pred
"scores": None, # get aligned src id for each prediction's valid tgt tokens
"attention": None, alignement = extract_alignment(
"batch": batch, alignment_attn, prediction_mask, src_lengths, n_best)
"gold_score": self._gold_score( return alignement
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)}
memory_lengths = src_lengths
src_map = batch.src_map if use_src_map else None
if isinstance(memory_bank, tuple):
mb_device = memory_bank[0].device
else:
mb_device = memory_bank.device
random_sampler = RandomSampling(
self._tgt_pad_idx, self._tgt_bos_idx, self._tgt_eos_idx,
batch_size, mb_device, min_length, self.block_ngram_repeat,
self._exclusion_idxs, return_attention, self.max_length,
sampling_temp, keep_topk, memory_lengths)
for step in range(max_length):
# Shape: (1, B, 1)
decoder_input = random_sampler.alive_seq[:, -1].view(1, -1, 1)
log_probs, attn = self._decode_and_generate(
decoder_input,
memory_bank,
batch,
src_vocabs,
memory_lengths=memory_lengths,
src_map=src_map,
step=step,
batch_offset=random_sampler.select_indices
)
random_sampler.advance(log_probs, attn)
any_batch_is_finished = random_sampler.is_finished.any()
if any_batch_is_finished:
random_sampler.update_finished()
if random_sampler.done:
break
if any_batch_is_finished:
select_indices = random_sampler.select_indices
# Reorder states.
if isinstance(memory_bank, tuple):
memory_bank = tuple(x.index_select(1, select_indices)
for x in memory_bank)
else:
memory_bank = memory_bank.index_select(1, select_indices)
memory_lengths = memory_lengths.index_select(0, select_indices)
if src_map is not None:
src_map = src_map.index_select(1, select_indices)
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
results["scores"] = random_sampler.scores
results["predictions"] = random_sampler.predictions
results["attention"] = random_sampler.attention
return results
def translate_batch(self, batch, src_vocabs, attn_debug): def translate_batch(self, batch, src_vocabs, attn_debug):
"""Translate a batch of sentences.""" """Translate a batch of sentences."""
with torch.no_grad(): with torch.no_grad():
if self.beam_size == 1: if self.beam_size == 1:
return self._translate_random_sampling( decode_strategy = GreedySearch(
batch, pad=self._tgt_pad_idx,
src_vocabs, bos=self._tgt_bos_idx,
self.max_length, eos=self._tgt_eos_idx,
min_length=self.min_length, batch_size=batch.batch_size,
min_length=self.min_length, max_length=self.max_length,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
return_attention=attn_debug or self.replace_unk,
sampling_temp=self.random_sampling_temp, sampling_temp=self.random_sampling_temp,
keep_topk=self.sample_from_topk, keep_topk=self.sample_from_topk)
return_attention=attn_debug or self.replace_unk)
else: else:
return self._translate_batch( # TODO: support these blacklisted features
batch, assert not self.dump_beam
src_vocabs, decode_strategy = BeamSearch(
self.max_length, self.beam_size,
min_length=self.min_length, batch_size=batch.batch_size,
ratio=self.ratio, pad=self._tgt_pad_idx,
bos=self._tgt_bos_idx,
eos=self._tgt_eos_idx,
n_best=self.n_best, n_best=self.n_best,
return_attention=attn_debug or self.replace_unk) global_scorer=self.global_scorer,
min_length=self.min_length, max_length=self.max_length,
return_attention=attn_debug or self.replace_unk,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
stepwise_penalty=self.stepwise_penalty,
ratio=self.ratio)
return self._translate_batch_with_strategy(batch, src_vocabs,
decode_strategy)
def _run_encoder(self, batch): def _run_encoder(self, batch):
src, src_lengths = batch.src if isinstance(batch.src, tuple) \ src, src_lengths = batch.src if isinstance(batch.src, tuple) \
@ -577,7 +604,8 @@ class Translator(object):
src_map) src_map)
# here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab] # here we have scores [tgt_lenxbatch, vocab] or [beamxbatch, vocab]
if batch_offset is None: if batch_offset is None:
scores = scores.view(batch.batch_size, -1, scores.size(-1)) scores = scores.view(-1, batch.batch_size, scores.size(-1))
scores = scores.transpose(0, 1).contiguous()
else: else:
scores = scores.view(-1, self.beam_size, scores.size(-1)) scores = scores.view(-1, self.beam_size, scores.size(-1))
scores = collapse_copy_scores( scores = collapse_copy_scores(
@ -594,21 +622,25 @@ class Translator(object):
# or [ tgt_len, batch_size, vocab ] when full sentence # or [ tgt_len, batch_size, vocab ] when full sentence
return log_probs, attn return log_probs, attn
def _translate_batch( def _translate_batch_with_strategy(
self, self,
batch, batch,
src_vocabs, src_vocabs,
max_length, decode_strategy):
min_length=0, """Translate a batch of sentences step by step using cache.
ratio=0.,
n_best=1,
return_attention=False):
# TODO: support these blacklisted features.
assert not self.dump_beam
Args:
batch: a batch of sentences, yield by data iterator.
src_vocabs (list): list of torchtext.data.Vocab if can_copy.
decode_strategy (DecodeStrategy): A decode strategy to use for
generate translation step by step.
Returns:
results (dict): The translation results.
"""
# (0) Prep the components of the search. # (0) Prep the components of the search.
use_src_map = self.copy_attn use_src_map = self.copy_attn
beam_size = self.beam_size parallel_paths = decode_strategy.parallel_paths # beam_size
batch_size = batch.batch_size batch_size = batch.batch_size
# (1) Run the encoder on the src. # (1) Run the encoder on the src.
@ -624,42 +656,16 @@ class Translator(object):
batch, memory_bank, src_lengths, src_vocabs, use_src_map, batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)} enc_states, batch_size, src)}
# (2) Repeat src objects `beam_size` times. # (2) prep decode_strategy. Possibly repeat src objects.
# We use batch_size x beam_size src_map = batch.src_map if use_src_map else None
src_map = (tile(batch.src_map, beam_size, dim=1) fn_map_state, memory_bank, memory_lengths, src_map = \
if use_src_map else None) decode_strategy.initialize(memory_bank, src_lengths, src_map)
self.model.decoder.map_state( if fn_map_state is not None:
lambda state, dim: tile(state, beam_size, dim=dim)) self.model.decoder.map_state(fn_map_state)
if isinstance(memory_bank, tuple): # (3) Begin decoding step by step:
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) for step in range(decode_strategy.max_length):
mb_device = memory_bank[0].device decoder_input = decode_strategy.current_predictions.view(1, -1, 1)
else:
memory_bank = tile(memory_bank, beam_size, dim=1)
mb_device = memory_bank.device
memory_lengths = tile(src_lengths, beam_size)
# (0) pt 2, prep the beam object
beam = BeamSearch(
beam_size,
n_best=n_best,
batch_size=batch_size,
global_scorer=self.global_scorer,
pad=self._tgt_pad_idx,
eos=self._tgt_eos_idx,
bos=self._tgt_bos_idx,
min_length=min_length,
ratio=ratio,
max_length=max_length,
mb_device=mb_device,
return_attention=return_attention,
stepwise_penalty=self.stepwise_penalty,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
memory_lengths=memory_lengths)
for step in range(max_length):
decoder_input = beam.current_predictions.view(1, -1, 1)
log_probs, attn = self._decode_and_generate( log_probs, attn = self._decode_and_generate(
decoder_input, decoder_input,
@ -669,18 +675,18 @@ class Translator(object):
memory_lengths=memory_lengths, memory_lengths=memory_lengths,
src_map=src_map, src_map=src_map,
step=step, step=step,
batch_offset=beam._batch_offset) batch_offset=decode_strategy.batch_offset)
beam.advance(log_probs, attn) decode_strategy.advance(log_probs, attn)
any_beam_is_finished = beam.is_finished.any() any_finished = decode_strategy.is_finished.any()
if any_beam_is_finished: if any_finished:
beam.update_finished() decode_strategy.update_finished()
if beam.done: if decode_strategy.done:
break break
select_indices = beam.current_origin select_indices = decode_strategy.select_indices
if any_beam_is_finished: if any_finished:
# Reorder states. # Reorder states.
if isinstance(memory_bank, tuple): if isinstance(memory_bank, tuple):
memory_bank = tuple(x.index_select(1, select_indices) memory_bank = tuple(x.index_select(1, select_indices)
@ -693,107 +699,18 @@ class Translator(object):
if src_map is not None: if src_map is not None:
src_map = src_map.index_select(1, select_indices) src_map = src_map.index_select(1, select_indices)
if parallel_paths > 1 or any_finished:
self.model.decoder.map_state( self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices)) lambda state, dim: state.index_select(dim, select_indices))
results["scores"] = beam.scores results["scores"] = decode_strategy.scores
results["predictions"] = beam.predictions results["predictions"] = decode_strategy.predictions
results["attention"] = beam.attention results["attention"] = decode_strategy.attention
return results if self.report_align:
results["alignment"] = self._align_forward(
# This is left in the code for now, but unsued batch, decode_strategy.predictions)
def _translate_batch_deprecated(self, batch, src_vocabs):
# (0) Prep each of the components of the search.
# And helper method for reducing verbosity.
use_src_map = self.copy_attn
beam_size = self.beam_size
batch_size = batch.batch_size
beam = [onmt.translate.Beam(
beam_size,
n_best=self.n_best,
cuda=self.cuda,
global_scorer=self.global_scorer,
pad=self._tgt_pad_idx,
eos=self._tgt_eos_idx,
bos=self._tgt_bos_idx,
min_length=self.min_length,
stepwise_penalty=self.stepwise_penalty,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs)
for __ in range(batch_size)]
# (1) Run the encoder on the src.
src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
self.model.decoder.init_state(src, memory_bank, enc_states)
results = {
"predictions": [],
"scores": [],
"attention": [],
"batch": batch,
"gold_score": self._gold_score(
batch, memory_bank, src_lengths, src_vocabs, use_src_map,
enc_states, batch_size, src)}
# (2) Repeat src objects `beam_size` times.
# We use now batch_size x beam_size (same as fast mode)
src_map = (tile(batch.src_map, beam_size, dim=1)
if use_src_map else None)
self.model.decoder.map_state(
lambda state, dim: tile(state, beam_size, dim=dim))
if isinstance(memory_bank, tuple):
memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
else: else:
memory_bank = tile(memory_bank, beam_size, dim=1) results["alignment"] = [[] for _ in range(batch_size)]
memory_lengths = tile(src_lengths, beam_size)
# (3) run the decoder to generate sentences, using beam search.
for i in range(self.max_length):
if all((b.done for b in beam)):
break
# (a) Construct batch x beam_size nxt words.
# Get all the pending current beam words and arrange for forward.
inp = torch.stack([b.current_predictions for b in beam])
inp = inp.view(1, -1, 1)
# (b) Decode and forward
out, beam_attn = self._decode_and_generate(
inp, memory_bank, batch, src_vocabs,
memory_lengths=memory_lengths, src_map=src_map, step=i
)
out = out.view(batch_size, beam_size, -1)
beam_attn = beam_attn.view(batch_size, beam_size, -1)
# (c) Advance each beam.
select_indices_array = []
# Loop over the batch_size number of beam
for j, b in enumerate(beam):
if not b.done:
b.advance(out[j, :],
beam_attn.data[j, :, :memory_lengths[j]])
select_indices_array.append(
b.current_origin + j * beam_size)
select_indices = torch.cat(select_indices_array)
self.model.decoder.map_state(
lambda state, dim: state.index_select(dim, select_indices))
# (4) Extract sentences from beam.
for b in beam:
scores, ks = b.sort_finished(minimum=self.n_best)
hyps, attn = [], []
for times, k in ks[:self.n_best]:
hyp, att = b.get_hyp(times, k)
hyps.append(hyp)
attn.append(att)
results["predictions"].append(hyps)
results["scores"].append(scores)
results["attention"].append(attn)
return results return results
def _score_target(self, batch, memory_bank, src_lengths, def _score_target(self, batch, memory_bank, src_lengths,
@ -816,31 +733,9 @@ class Translator(object):
if words_total == 0: if words_total == 0:
msg = "%s No words predicted" % (name,) msg = "%s No words predicted" % (name,)
else: else:
avg_score = score_total / words_total
ppl = np.exp(-score_total.item() / words_total)
msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % ( msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
name, score_total / words_total, name, avg_score,
name, math.exp(-score_total / words_total))) name, ppl))
return msg
def _report_bleu(self, tgt_path):
import subprocess
base_dir = os.path.abspath(__file__ + "/../../..")
# Rollback pointer to the beginning.
self.out_file.seek(0)
print()
res = subprocess.check_output(
"perl %s/tools/multi-bleu.perl %s" % (base_dir, tgt_path),
stdin=self.out_file, shell=True
).decode("utf-8")
msg = ">> " + res.strip()
return msg
def _report_rouge(self, tgt_path):
import subprocess
path = os.path.split(os.path.realpath(__file__))[0]
msg = subprocess.check_output(
"python %s/tools/test_rouge.py -r %s -c STDIN" % (path, tgt_path),
shell=True, stdin=self.out_file
).decode("utf-8").strip()
return msg return msg

Просмотреть файл

@ -1,5 +1,6 @@
"""Module defining various utilities.""" """Module defining various utilities."""
from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed
from onmt.utils.alignment import make_batch_align_matrix
from onmt.utils.report_manager import ReportMgr, build_report_manager from onmt.utils.report_manager import ReportMgr, build_report_manager
from onmt.utils.statistics import Statistics from onmt.utils.statistics import Statistics
from onmt.utils.optimizers import MultipleOptimizer, \ from onmt.utils.optimizers import MultipleOptimizer, \
@ -9,4 +10,4 @@ from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts
__all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr",
"build_report_manager", "Statistics", "build_report_manager", "Statistics",
"MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping", "MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping",
"scorers_from_opts"] "scorers_from_opts", "make_batch_align_matrix"]

Просмотреть файл

@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
import torch
from itertools import accumulate
def make_batch_align_matrix(index_tensor, size=None, normalize=False):
"""
Convert a sparse index_tensor into a batch of alignment matrix,
with row normalize to the sum of 1 if set normalize.
Args:
index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id]
size (List[int]): Size of the sparse tensor.
        normalize (bool): whether to normalize the 2nd dim of the resulting tensor.
"""
n_fill, device = index_tensor.size(0), index_tensor.device
value_tensor = torch.ones([n_fill], dtype=torch.float)
dense_tensor = torch.sparse_coo_tensor(
index_tensor.t(), value_tensor, size=size, device=device).to_dense()
if normalize:
row_sum = dense_tensor.sum(-1, keepdim=True) # sum by row(tgt)
# threshold on 1 to avoid div by 0
torch.nn.functional.threshold(row_sum, 1, 1, inplace=True)
dense_tensor.div_(row_sum)
return dense_tensor
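A concrete toy example of this sparse-to-dense conversion (values invented): three [batch_id, tgt_id, src_id] links for one sentence pair, row-normalized over the source dimension.

```python
import torch

index_tensor = torch.tensor([[0, 0, 0],   # [batch_id, tgt_id, src_id]
                             [0, 1, 1],
                             [0, 1, 2]])
values = torch.ones(index_tensor.size(0))
dense = torch.sparse_coo_tensor(index_tensor.t(), values,
                                size=(1, 2, 3)).to_dense()
row_sum = dense.sum(-1, keepdim=True).clamp(min=1)  # same effect as the threshold
print(dense / row_sum)
# tensor([[[1.0000, 0.0000, 0.0000],
#          [0.0000, 0.5000, 0.5000]]])
```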
def extract_alignment(align_matrix, tgt_mask, src_lens, n_best):
"""
Extract a batched align_matrix into its src indice alignment lists,
with tgt_mask to filter out invalid tgt position as EOS/PAD.
BOS already excluded from tgt_mask in order to match prediction.
Args:
align_matrix (Tensor): ``(B, tgt_len, src_len)``,
attention head normalized by Softmax(dim=-1)
tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD.
src_lens (LongTensor): ``(B,)``, containing valid src length
n_best (int): a value indicating number of parallel translation.
* B: denote flattened batch as B = batch_size * n_best.
Returns:
alignments (List[List[FloatTensor|None]]): ``(batch_size, n_best,)``,
containing valid alignment matrix (or None if blank prediction)
for each translation.
"""
batch_size_n_best = align_matrix.size(0)
assert batch_size_n_best % n_best == 0
alignments = [[] for _ in range(batch_size_n_best // n_best)]
# treat alignment matrix one by one as each have different lengths
for i, (am_b, tgt_mask_b, src_len) in enumerate(
zip(align_matrix, tgt_mask, src_lens)):
valid_tgt = ~tgt_mask_b
valid_tgt_len = valid_tgt.sum()
if valid_tgt_len == 0:
            # No alignment if there is no valid tgt token
            valid_alignment = None
        else:
            # get valid alignment (sub-matrix of the full padded alignment matrix)
am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)) \
.view(valid_tgt_len, -1)
valid_alignment = am_valid_tgt[:, :src_len] # only keep valid src
alignments[i // n_best].append(valid_alignment)
return alignments
def build_align_pharaoh(valid_alignment):
"""Convert valid alignment matrix to i-j (from 0) Pharaoh format pairs,
or empty list if it's None.
"""
align_pairs = []
if isinstance(valid_alignment, torch.Tensor):
tgt_align_src_id = valid_alignment.argmax(dim=-1)
for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()):
align_pairs.append(str(src_id) + "-" + str(tgt_id))
align_pairs.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
align_pairs.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
return align_pairs
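To make the Pharaoh conversion concrete, a tiny invented 2x3 alignment matrix yields one `src-tgt` pair per target position via argmax over the source dimension:

```python
import torch

valid_alignment = torch.tensor([[0.7, 0.2, 0.1],   # tgt 0 -> src 0
                                [0.1, 0.1, 0.8]])  # tgt 1 -> src 2
tgt_align_src_id = valid_alignment.argmax(dim=-1).tolist()
pairs = ["{}-{}".format(src_id, tgt_id)
         for tgt_id, src_id in enumerate(tgt_align_src_id)]
print(" ".join(pairs))  # 0-0 2-1
```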
def to_word_align(src, tgt, subword_align, mode):
"""Convert subword alignment to word alignment.
Args:
src (string): tokenized sentence in source language.
tgt (string): tokenized sentence in target language.
        subword_align (string): align_pharaoh output corresponding to src-tgt.
        mode (string): tokenization mode used by src and tgt,
            choose from ["joiner", "spacer"].

    Returns:
        word_align (string): converted alignments corresponding to the
            detokenized src-tgt.
"""
src, tgt = src.strip().split(), tgt.strip().split()
subword_align = {(int(a), int(b)) for a, b in (x.split("-")
for x in subword_align.split())}
if mode == 'joiner':
        src_map = subword_map_by_joiner(src, marker='￭')
        tgt_map = subword_map_by_joiner(tgt, marker='￭')
    elif mode == 'spacer':
        src_map = subword_map_by_spacer(src, marker='▁')
        tgt_map = subword_map_by_spacer(tgt, marker='▁')
else:
raise ValueError("Invalid value for argument mode!")
word_align = list({"{}-{}".format(src_map[a], tgt_map[b])
for a, b in subword_align})
word_align.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id
word_align.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id
return " ".join(word_align)
def subword_map_by_joiner(subwords, marker='￭'):
"""Return word id for each subword token (annotate by joiner)."""
flags = [0] * len(subwords)
for i, tok in enumerate(subwords):
if tok.endswith(marker):
flags[i] = 1
if tok.startswith(marker):
assert i >= 1 and flags[i-1] != 1, \
"Sentence `{}` not correct!".format(" ".join(subwords))
flags[i-1] = 1
marker_acc = list(accumulate([0] + flags[:-1]))
word_group = [(i - maker_sofar) for i, maker_sofar
in enumerate(marker_acc)]
return word_group
def subword_map_by_spacer(subwords, marker='▁'):
"""Return word id for each subword token (annotate by spacer)."""
word_group = list(accumulate([int(marker in x) for x in subwords]))
if word_group[0] == 1: # when dummy prefix is set
word_group = [item - 1 for item in word_group]
return word_group
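Assuming an installed OpenNMT-py that provides onmt.utils.alignment, a short usage sketch of the subword-to-word conversion (tokens and alignment invented, joiner-annotated with pyonmttok's ￭ marker):

```python
from onmt.utils.alignment import to_word_align

src = "Ab ￭sol ￭ut ￭ely"            # 4 subwords forming 1 source word
tgt = "Abso ￭lument"                # 2 subwords forming 1 target word
subword_align = "0-0 1-0 2-1 3-1"   # subword-level Pharaoh pairs

print(to_word_align(src, tgt, subword_align, mode="joiner"))  # 0-0
```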

Просмотреть файл

@ -2,11 +2,11 @@
from __future__ import absolute_import from __future__ import absolute_import
import logging import logging
from logging.handlers import RotatingFileHandler
logger = logging.getLogger()


def init_logger(log_file=None, log_file_level=logging.NOTSET):
def init_logger(log_file=None, log_file_level=logging.NOTSET, rotate=False):
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
@ -16,6 +16,10 @@ def init_logger(log_file=None, log_file_level=logging.NOTSET):
logger.handlers = [console_handler] logger.handlers = [console_handler]
if log_file and log_file != '': if log_file and log_file != '':
if rotate:
file_handler = RotatingFileHandler(
log_file, maxBytes=1000000, backupCount=10)
else:
            file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(log_file_level)
        file_handler.setFormatter(log_format)
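The new rotate flag simply swaps in the standard library's RotatingFileHandler (1 MB per file, 10 backups); a minimal stand-alone equivalent (the file name is invented):

```python
import logging
from logging.handlers import RotatingFileHandler

log = logging.getLogger("translation_server")
handler = RotatingFileHandler("server.log", maxBytes=1000000, backupCount=10)
handler.setFormatter(logging.Formatter("[%(asctime)s %(levelname)s] %(message)s"))
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info("model 100 loaded")  # rolls over to server.log.1 ... server.log.10
```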

Просмотреть файл

@ -57,7 +57,8 @@ def build_loss_compute(model, tgt_field, opt, train=True):
) )
else: else:
compute = NMTLossCompute( compute = NMTLossCompute(
criterion, loss_gen, lambda_coverage=opt.lambda_coverage) criterion, loss_gen, lambda_coverage=opt.lambda_coverage,
lambda_align=opt.lambda_align)
compute.to(device) compute.to(device)
return compute return compute
@ -226,9 +227,10 @@ class NMTLossCompute(LossComputeBase):
""" """
def __init__(self, criterion, generator, normalization="sents", def __init__(self, criterion, generator, normalization="sents",
lambda_coverage=0.0): lambda_coverage=0.0, lambda_align=0.0):
super(NMTLossCompute, self).__init__(criterion, generator) super(NMTLossCompute, self).__init__(criterion, generator)
self.lambda_coverage = lambda_coverage self.lambda_coverage = lambda_coverage
self.lambda_align = lambda_align
def _make_shard_state(self, batch, output, range_, attns=None): def _make_shard_state(self, batch, output, range_, attns=None):
shard_state = { shard_state = {
@ -248,10 +250,33 @@ class NMTLossCompute(LossComputeBase):
"std_attn": attns.get("std"), "std_attn": attns.get("std"),
"coverage_attn": coverage "coverage_attn": coverage
}) })
if self.lambda_align != 0.0:
# attn_align should be in (batch_size, pad_tgt_size, pad_src_size)
attn_align = attns.get("align", None)
# align_idx should be a Tensor in size([N, 3]), N is total number
# of align src-tgt pair in current batch, each as
# ['sent_N°_in_batch', 'tgt_id+1', 'src_id'] (check AlignField)
align_idx = batch.align
assert attns is not None
assert attn_align is not None, "lambda_align != 0.0 requires " \
"alignement attention head"
assert align_idx is not None, "lambda_align != 0.0 requires " \
"provide guided alignement"
pad_tgt_size, batch_size, _ = batch.tgt.size()
pad_src_size = batch.src[0].size(0)
align_matrix_size = [batch_size, pad_tgt_size, pad_src_size]
ref_align = onmt.utils.make_batch_align_matrix(
align_idx, align_matrix_size, normalize=True)
                # NOTE: tgt-src ref alignment that lies within range_ of the shard
# (coherent with batch.tgt)
shard_state.update({
"align_head": attn_align,
"ref_align": ref_align[:, range_[0] + 1: range_[1], :]
})
return shard_state return shard_state
def _compute_loss(self, batch, output, target, std_attn=None, def _compute_loss(self, batch, output, target, std_attn=None,
coverage_attn=None): coverage_attn=None, align_head=None, ref_align=None):
bottled_output = self._bottle(output) bottled_output = self._bottle(output)
@ -263,15 +288,33 @@ class NMTLossCompute(LossComputeBase):
coverage_loss = self._compute_coverage_loss( coverage_loss = self._compute_coverage_loss(
std_attn=std_attn, coverage_attn=coverage_attn) std_attn=std_attn, coverage_attn=coverage_attn)
loss += coverage_loss loss += coverage_loss
if self.lambda_align != 0.0:
if align_head.dtype != loss.dtype: # Fix FP16
align_head = align_head.to(loss.dtype)
if ref_align.dtype != loss.dtype:
ref_align = ref_align.to(loss.dtype)
align_loss = self._compute_alignement_loss(
align_head=align_head, ref_align=ref_align)
loss += align_loss
stats = self._stats(loss.clone(), scores, gtruth) stats = self._stats(loss.clone(), scores, gtruth)
return loss, stats return loss, stats
def _compute_coverage_loss(self, std_attn, coverage_attn): def _compute_coverage_loss(self, std_attn, coverage_attn):
covloss = torch.min(std_attn, coverage_attn).sum(2).view(-1) covloss = torch.min(std_attn, coverage_attn).sum()
covloss *= self.lambda_coverage covloss *= self.lambda_coverage
return covloss return covloss
def _compute_alignement_loss(self, align_head, ref_align):
"""Compute loss between 2 partial alignment matrix."""
# align_head contains value in [0, 1) presenting attn prob,
# 0 was resulted by the context attention src_pad_mask
# So, the correspand position in ref_align should also be 0
# Therefore, clip align_head to > 1e-18 should be bias free.
align_loss = -align_head.clamp(min=1e-18).log().mul(ref_align).sum()
align_loss *= self.lambda_align
return align_loss
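A numeric sanity check of the guided-alignment term (toy tensors, not from the commit): it is a cross entropy between the reference rows and the clamped alignment head, scaled by lambda_align.

```python
import torch

lambda_align = 0.05
align_head = torch.tensor([[0.9, 0.1, 0.0],   # predicted alignment attention
                           [0.2, 0.7, 0.1]])
ref_align = torch.tensor([[1.0, 0.0, 0.0],    # row-normalized reference
                          [0.0, 1.0, 0.0]])
align_loss = -align_head.clamp(min=1e-18).log().mul(ref_align).sum()
print(lambda_align * align_loss)  # 0.05 * (-log 0.9 - log 0.7) ~= 0.0231
```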
def filter_shard_state(state, shard_size=None): def filter_shard_state(state, shard_size=None):
for k, v in state.items(): for k, v in state.items():

Просмотреть файл

@ -3,10 +3,23 @@
import torch import torch
import random import random
import inspect import inspect
from itertools import islice from itertools import islice, repeat
import os
def split_corpus(path, shard_size):
def split_corpus(path, shard_size, default=None):
    """Yield a `list` containing `shard_size` lines of `path`,
    or repeatedly yield `default` if `path` is None.
    """
if path is not None:
return _split_corpus(path, shard_size)
else:
return repeat(default)
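Assuming the rest of _split_corpus yields successive shards until EOF (only its head is shown in this hunk), the new default argument lets preprocessing zip an optional alignment corpus with the src/tgt shards; a small self-contained sketch with throwaway files:

```python
from onmt.utils.misc import split_corpus

# create two tiny throwaway corpora so the example runs anywhere
with open("train.src", "w") as f:
    f.write("a b c\nd e f\n")
with open("train.tgt", "w") as f:
    f.write("A B C\nD E F\n")

src_shards = split_corpus("train.src", shard_size=1)
tgt_shards = split_corpus("train.tgt", shard_size=1)
align_shards = split_corpus(None, shard_size=1, default=None)  # yields None forever

for src, tgt, align in zip(src_shards, tgt_shards, align_shards):
    print(len(src), len(tgt), align)  # 1 1 None  (twice)
```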
def _split_corpus(path, shard_size):
"""Yield a `list` containing `shard_size` line of `path`.
"""
with open(path, "rb") as f: with open(path, "rb") as f:
if shard_size <= 0: if shard_size <= 0:
yield f.readlines() yield f.readlines()
@ -124,3 +137,37 @@ def relative_matmul(x, z, transpose):
def fn_args(fun): def fn_args(fun):
"""Returns the list of function arguments name.""" """Returns the list of function arguments name."""
return inspect.getfullargspec(fun).args return inspect.getfullargspec(fun).args
def report_matrix(row_label, column_label, matrix):
header_format = "{:>10.10} " + "{:>10.7} " * len(row_label)
row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
output = header_format.format("", *row_label) + '\n'
for word, row in zip(column_label, matrix):
max_index = row.index(max(row))
row_format = row_format.replace(
"{:>10.7f} ", "{:*>10.7f} ", max_index + 1)
row_format = row_format.replace(
"{:*>10.7f} ", "{:>10.7f} ", max_index)
output += row_format.format(word, *row) + '\n'
row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label)
return output
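report_matrix is the helper that now renders the -attn_debug and -align_debug tables; a small invented example:

```python
from onmt.utils.misc import report_matrix

srcs = ["the", "cat"]        # row labels (source tokens)
preds = ["le", "chat"]       # column labels (predicted tokens)
attns = [[0.9, 0.1],         # one attention row per predicted token
         [0.2, 0.8]]
print(report_matrix(srcs, preds, attns))  # the max of each row is starred
```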
def check_model_config(model_config, root):
# we need to check the model path + any tokenizer path
for model in model_config["models"]:
model_path = os.path.join(root, model)
if not os.path.exists(model_path):
raise FileNotFoundError(
"{} from model {} does not exist".format(
model_path, model_config["id"]))
if "tokenizer" in model_config.keys():
if "params" in model_config["tokenizer"].keys():
for k, v in model_config["tokenizer"]["params"].items():
if k.endswith("path"):
tok_path = os.path.join(root, v)
if not os.path.exists(tok_path):
raise FileNotFoundError(
"{} from model {} does not exist".format(
tok_path, model_config["id"]))
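The entries being validated here come from the server configuration; a minimal hypothetical model entry that exercises both checks (the file names are placeholders) looks like this:

```python
from onmt.utils.misc import check_model_config

model_config = {
    "id": 100,
    "models": ["model.pt"],                    # resolved against the model root
    "tokenizer": {
        "type": "sentencepiece",
        "params": {"model_path": "sp.model"}   # any "*path" param is checked too
    }
}
try:
    check_model_config(model_config, root="./model")
except FileNotFoundError as e:
    print(e)  # raised when a listed file is missing under ./model
```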

Просмотреть файл

@ -6,6 +6,9 @@ import operator
import functools import functools
from copy import copy from copy import copy
from math import sqrt from math import sqrt
import types
import importlib
from onmt.utils.misc import fn_args
def build_torch_optimizer(model, opt): def build_torch_optimizer(model, opt):
@ -75,8 +78,8 @@ def build_torch_optimizer(model, opt):
betas=betas, betas=betas,
eps=1e-8)]) eps=1e-8)])
elif opt.optim == 'fusedadam': elif opt.optim == 'fusedadam':
import apex # we use here a FusedAdam() copy of an old Apex repo
optimizer = apex.optimizers.FusedAdam( optimizer = FusedAdam(
params, params,
lr=opt.learning_rate, lr=opt.learning_rate,
betas=betas) betas=betas)
@ -85,14 +88,23 @@ def build_torch_optimizer(model, opt):
if opt.model_dtype == 'fp16': if opt.model_dtype == 'fp16':
import apex import apex
if opt.optim != 'fusedadam':
# In this case use the new AMP API from apex
loss_scale = "dynamic" if opt.loss_scale == 0 else opt.loss_scale loss_scale = "dynamic" if opt.loss_scale == 0 else opt.loss_scale
model, optimizer = apex.amp.initialize( model, optimizer = apex.amp.initialize(
[model, model.generator], [model, model.generator],
optimizer, optimizer,
opt_level=opt.apex_opt_level, opt_level=opt.apex_opt_level,
loss_scale=loss_scale, loss_scale=loss_scale,
keep_batchnorm_fp32=False if opt.optim == "fusedadam" else None) keep_batchnorm_fp32=None)
else:
# In this case use the old FusedAdam with FP16_optimizer wrapper
static_loss_scale = opt.loss_scale
dynamic_loss_scale = opt.loss_scale == 0
optimizer = apex.contrib.optimizers.FP16_Optimizer(
optimizer,
static_loss_scale=static_loss_scale,
dynamic_loss_scale=dynamic_loss_scale)
return optimizer return optimizer
@ -222,8 +234,7 @@ class Optimizer(object):
self._max_grad_norm = max_grad_norm or 0 self._max_grad_norm = max_grad_norm or 0
self._training_step = 1 self._training_step = 1
self._decay_step = 1 self._decay_step = 1
self._with_fp16_wrapper = ( self._fp16 = None
optimizer.__class__.__name__ == "FP16_Optimizer")
@classmethod @classmethod
def from_opt(cls, model, opt, checkpoint=None): def from_opt(cls, model, opt, checkpoint=None):
@ -273,6 +284,11 @@ class Optimizer(object):
optim_opt.learning_rate, optim_opt.learning_rate,
learning_rate_decay_fn=make_learning_rate_decay_fn(optim_opt), learning_rate_decay_fn=make_learning_rate_decay_fn(optim_opt),
max_grad_norm=optim_opt.max_grad_norm) max_grad_norm=optim_opt.max_grad_norm)
if opt.model_dtype == "fp16":
if opt.optim == "fusedadam":
optimizer._fp16 = "legacy"
else:
optimizer._fp16 = "amp"
if optim_state_dict: if optim_state_dict:
optimizer.load_state_dict(optim_state_dict) optimizer.load_state_dict(optim_state_dict)
return optimizer return optimizer
@ -311,10 +327,15 @@ class Optimizer(object):
def backward(self, loss): def backward(self, loss):
"""Wrapper for backward pass. Some optimizer requires ownership of the """Wrapper for backward pass. Some optimizer requires ownership of the
backward pass.""" backward pass."""
if self._with_fp16_wrapper: if self._fp16 == "amp":
import apex import apex
with apex.amp.scale_loss(loss, self._optimizer) as scaled_loss: with apex.amp.scale_loss(loss, self._optimizer) as scaled_loss:
scaled_loss.backward() scaled_loss.backward()
elif self._fp16 == "legacy":
kwargs = {}
if "update_master_grads" in fn_args(self._optimizer.backward):
kwargs["update_master_grads"] = True
self._optimizer.backward(loss, **kwargs)
else: else:
loss.backward() loss.backward()
@ -325,17 +346,16 @@ class Optimizer(object):
rate. rate.
""" """
learning_rate = self.learning_rate() learning_rate = self.learning_rate()
if self._with_fp16_wrapper: if self._fp16 == "legacy":
if hasattr(self._optimizer, "update_master_grads"): if hasattr(self._optimizer, "update_master_grads"):
self._optimizer.update_master_grads() self._optimizer.update_master_grads()
if hasattr(self._optimizer, "clip_master_grads") and \ if hasattr(self._optimizer, "clip_master_grads") and \
self._max_grad_norm > 0: self._max_grad_norm > 0:
import apex self._optimizer.clip_master_grads(self._max_grad_norm)
torch.nn.utils.clip_grad_norm_(
apex.amp.master_params(self), self._max_grad_norm)
for group in self._optimizer.param_groups: for group in self._optimizer.param_groups:
group['lr'] = learning_rate group['lr'] = learning_rate
if not self._with_fp16_wrapper and self._max_grad_norm > 0: if self._fp16 is None and self._max_grad_norm > 0:
clip_grad_norm_(group['params'], self._max_grad_norm) clip_grad_norm_(group['params'], self._max_grad_norm)
self._optimizer.step() self._optimizer.step()
self._decay_step += 1 self._decay_step += 1
@ -513,3 +533,156 @@ class AdaFactor(torch.optim.Optimizer):
p.data.add_(-group['weight_decay'] * lr_t, p.data) p.data.add_(-group['weight_decay'] * lr_t, p.data)
return loss return loss
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only.
Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square.
(default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
weight_decay=0., max_grad_norm=0., amsgrad=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
if amsgrad:
raise RuntimeError('AMSGrad variant not supported.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(FusedAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
def step(self, closure=None, grads=None, output_params=None,
scale=1., grad_norms=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grads (list of tensors, optional): weight gradient to use for the
optimizer update. If gradients have type torch.half, parameters
are expected to be in type torch.float. (default: None)
output params (list of tensors, optional): A reduced precision copy
of the updated weights written out in addition to the regular
updated weights. Have to be of same type as gradients.
(default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
loss = None
if closure is not None:
loss = closure()
if grads is None:
grads_group = [None]*len(self.param_groups)
# backward compatibility
# assuming a list/generator of parameter means single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0]) != list:
grads_group = [grads]
else:
grads_group = grads
if output_params is None:
output_params_group = [None]*len(self.param_groups)
elif isinstance(output_params, types.GeneratorType):
output_params_group = [output_params]
elif type(output_params[0]) != list:
output_params_group = [output_params]
else:
output_params_group = output_params
if grad_norms is None:
grad_norms = [None]*len(self.param_groups)
for group, grads_this_group, output_params_this_group, \
grad_norm in zip(self.param_groups, grads_group,
output_params_group, grad_norms):
if grads_this_group is None:
grads_this_group = [None]*len(group['params'])
if output_params_this_group is None:
output_params_this_group = [None]*len(group['params'])
# compute combined scale factor for this group
combined_scale = scale
if group['max_grad_norm'] > 0:
# norm is in fact norm*scale
clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
if clip > 1:
combined_scale = clip * scale
bias_correction = 1 if group['bias_correction'] else 0
for p, grad, output_param in zip(group['params'],
grads_this_group,
output_params_this_group):
# note: p.grad should not ever be set for correct operation of
# mixed precision optimizer that sometimes sends None gradients
if p.grad is None and grad is None:
continue
if grad is None:
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('FusedAdam does not support sparse \
gradients, please consider \
SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
out_p = torch.tensor([], dtype=torch.float) if output_param \
is None else output_param
fused_adam_cuda.adam(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
return loss
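
For context, a minimal usage sketch of this fused optimizer. This is illustrative only: it assumes apex's `fused_adam_cuda` extension is built, a CUDA device is available, and that `FusedAdam` is exposed from `onmt.utils.optimizers` as in upstream OpenNMT-py; the model and hyperparameters are placeholders.

```
import torch
from onmt.utils.optimizers import FusedAdam  # import path assumed from upstream OpenNMT-py

model = torch.nn.Linear(512, 512).cuda()
# eps_inside_sqrt=True adds eps before the square root (sqrt(v_hat + eps));
# the default False keeps the original Adam denominator sqrt(v_hat) + eps,
# which corresponds to eps_mode == 1 in the code above.
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-8, eps_inside_sqrt=False, weight_decay=0.0)

loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optimizer.step()       # grads/output_params/scale are only needed for FP16 training
optimizer.zero_grad()
```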

View file

@@ -46,6 +46,11 @@ class ArgumentParser(cfargparse.ArgumentParser):
         if model_opt.copy_attn_type is None:
             model_opt.copy_attn_type = model_opt.global_attention
+        if model_opt.alignment_layer is None:
+            model_opt.alignment_layer = -2
+            model_opt.lambda_align = 0.0
+            model_opt.full_context_alignment = False

     @classmethod
     def validate_model_opts(cls, model_opt):
         assert model_opt.model_type in ["text", "img", "audio", "vec"], \
@@ -63,6 +68,17 @@ class ArgumentParser(cfargparse.ArgumentParser):
             if model_opt.model_type != "text":
                 raise AssertionError(
                     "--share_embeddings requires --model_type text.")
+        if model_opt.lambda_align > 0.0:
+            assert model_opt.decoder_type == 'transformer', \
+                "Only transformer is supported to joint learn alignment."
+            assert model_opt.alignment_layer < model_opt.dec_layers and \
+                model_opt.alignment_layer >= -model_opt.dec_layers, \
+                "N° alignment_layer should be smaller than number of layers."
+            logger.info("Joint learn alignment at layer [{}] "
+                        "with {} heads in full_context '{}'.".format(
+                            model_opt.alignment_layer,
+                            model_opt.alignment_heads,
+                            model_opt.full_context_alignment))

     @classmethod
     def ckpt_model_opts(cls, ckpt_opt):
@@ -85,8 +101,7 @@ class ArgumentParser(cfargparse.ArgumentParser):
             raise AssertionError(
                 "gpuid is deprecated see world_size and gpu_ranks")
         if torch.cuda.is_available() and not opt.gpu_ranks:
-            logger.info("WARNING: You have a CUDA device, \
-                        should run with -gpu_ranks")
+            logger.warn("You have a CUDA device, should run with -gpu_ranks")
         if opt.world_size < len(opt.gpu_ranks):
             raise AssertionError(
                 "parameter counts of -gpu_ranks must be less or equal "
@@ -128,6 +143,18 @@ class ArgumentParser(cfargparse.ArgumentParser):
         for file in opt.train_src + opt.train_tgt:
             assert os.path.isfile(file), "Please check path of %s" % file

+        if len(opt.train_align) == 1 and opt.train_align[0] is None:
+            opt.train_align = [None] * len(opt.train_src)
+        else:
+            assert len(opt.train_align) == len(opt.train_src), \
+                "Please provide same number of word alignment train \
+                files as src/tgt!"
+            for file in opt.train_align:
+                assert os.path.isfile(file), "Please check path of %s" % file
+
+        assert not opt.valid_align or os.path.isfile(opt.valid_align), \
+            "Please check path of your valid alignment file!"
+
         assert not opt.valid_src or os.path.isfile(opt.valid_src), \
             "Please check path of your valid src file!"
         assert not opt.valid_tgt or os.path.isfile(opt.valid_tgt), \
View file

@@ -8,16 +8,15 @@ import onmt
 from onmt.utils.logging import logger

-def build_report_manager(opt):
-    if opt.tensorboard:
-        from tensorboardX import SummaryWriter
+def build_report_manager(opt, gpu_rank):
+    if opt.tensorboard and gpu_rank == 0:
+        from torch.utils.tensorboard import SummaryWriter
         tensorboard_log_dir = opt.tensorboard_log_dir
         if not opt.train_from:
             tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S")
-        writer = SummaryWriter(tensorboard_log_dir,
-                               comment="Unmt")
+        writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")
     else:
         writer = None
@@ -42,7 +41,6 @@ class ReportMgrBase(object):
                 means that you will need to set it later or use `start()`
         """
         self.report_every = report_every
-        self.progress_step = 0
         self.start_time = start_time

     def start(self):
@@ -75,7 +73,6 @@ class ReportMgrBase(object):
                 onmt.utils.Statistics.all_gather_stats(report_stats)
             self._report_training(
                 step, num_steps, learning_rate, report_stats)
-            self.progress_step += 1
             return onmt.utils.Statistics()
         else:
             return report_stats
@@ -127,11 +124,10 @@ class ReportMgr(ReportMgrBase):
         report_stats.output(step, num_steps,
                             learning_rate, self.start_time)

-        # Log the progress using the number of batches on the x-axis.
         self.maybe_log_tensorboard(report_stats,
                                    "progress",
                                    learning_rate,
-                                   self.progress_step)
+                                   step)
         report_stats = onmt.utils.Statistics()
         return report_stats
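
The net effect of this hunk: only `gpu_rank` 0 creates a writer, the writer comes from `torch.utils.tensorboard` instead of the external `tensorboardX` package, and values are logged at the real training step rather than a separate `progress_step` counter. A minimal sketch of the new logging backend (the tag and value are placeholders):

```
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter  # needs tensorboard>=1.14, see setup.py below

gpu_rank = 0
if gpu_rank == 0:  # only the first rank writes
    log_dir = "runs/onmt" + datetime.now().strftime("/%b-%d_%H-%M-%S")
    writer = SummaryWriter(log_dir, comment="Unmt")
    writer.add_scalar("progress/xent", 3.21, global_step=100)  # x-axis is now the training step
    writer.close()
```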

18
opennmt/prepare_opts.py Normal file
View file

@@ -0,0 +1,18 @@
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
import os
import pickle

dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

def main():
    parser = ArgumentParser()
    opts.config_opts(parser)
    opts.model_opts(parser)
    opts.global_opts(parser)
    opt = parser.parse_args()
    with open(os.path.join(dir_path, 'opt_data'), 'wb') as f:
        pickle.dump(opt, f)

if __name__ == "__main__":
    main()
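
This helper pickles the parsed OpenNMT options to `opt_data` at the repository root (the binary file tracked further down), so the app can reload them without re-parsing. A hedged sketch of consuming it after running `python opennmt/prepare_opts.py -config <your-config>` (any flags required by `opts.global_opts` must be supplied as well):

```
import pickle

# 'opt_data' is written one level above the opennmt/ package, i.e. at the repo root
with open("opt_data", "rb") as f:
    opt = pickle.load(f)   # an argparse Namespace with the parsed options
print(opt)
```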

View file

@@ -1,220 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+from onmt.bin.preprocess import main
"""
Pre-process Data / features files and build vocabulary
"""
import codecs
import glob
import sys
import gc
import torch
from functools import partial
from collections import Counter, defaultdict
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import split_corpus
import onmt.inputters as inputters
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import _build_fields_vocab,\
_load_vocab
def check_existing_pt_files(opt):
""" Check if there are existing .pt files to avoid overwriting them """
pattern = opt.save_data + '.{}*.pt'
for t in ['train', 'valid']:
path = pattern.format(t)
if glob.glob(path):
sys.stderr.write("Please backup existing pt files: %s, "
"to avoid overwriting them!\n" % path)
sys.exit(1)
def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
assert corpus_type in ['train', 'valid']
if corpus_type == 'train':
counters = defaultdict(Counter)
srcs = opt.train_src
tgts = opt.train_tgt
ids = opt.train_ids
else:
srcs = [opt.valid_src]
tgts = [opt.valid_tgt]
ids = [None]
for src, tgt, maybe_id in zip(srcs, tgts, ids):
logger.info("Reading source and target files: %s %s." % (src, tgt))
src_shards = split_corpus(src, opt.shard_size)
tgt_shards = split_corpus(tgt, opt.shard_size)
shard_pairs = zip(src_shards, tgt_shards)
dataset_paths = []
if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
filter_pred = partial(
inputters.filter_example, use_src_len=opt.data_type == "text",
max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
else:
filter_pred = None
if corpus_type == "train":
existing_fields = None
if opt.src_vocab != "":
try:
logger.info("Using existing vocabulary...")
existing_fields = torch.load(opt.src_vocab)
except torch.serialization.pickle.UnpicklingError:
logger.info("Building vocab from text file...")
src_vocab, src_vocab_size = _load_vocab(
opt.src_vocab, "src", counters,
opt.src_words_min_frequency)
else:
src_vocab = None
if opt.tgt_vocab != "":
tgt_vocab, tgt_vocab_size = _load_vocab(
opt.tgt_vocab, "tgt", counters,
opt.tgt_words_min_frequency)
else:
tgt_vocab = None
for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
assert len(src_shard) == len(tgt_shard)
logger.info("Building shard %d." % i)
dataset = inputters.Dataset(
fields,
readers=([src_reader, tgt_reader]
if tgt_reader else [src_reader]),
data=([("src", src_shard), ("tgt", tgt_shard)]
if tgt_reader else [("src", src_shard)]),
dirs=([opt.src_dir, None]
if tgt_reader else [opt.src_dir]),
sort_key=inputters.str2sortkey[opt.data_type],
filter_pred=filter_pred
)
if corpus_type == "train" and existing_fields is None:
for ex in dataset.examples:
for name, field in fields.items():
try:
f_iter = iter(field)
except TypeError:
f_iter = [(name, field)]
all_data = [getattr(ex, name, None)]
else:
all_data = getattr(ex, name)
for (sub_n, sub_f), fd in zip(
f_iter, all_data):
has_vocab = (sub_n == 'src' and
src_vocab is not None) or \
(sub_n == 'tgt' and
tgt_vocab is not None)
if (hasattr(sub_f, 'sequential')
and sub_f.sequential and not has_vocab):
val = fd
counters[sub_n].update(val)
if maybe_id:
shard_base = corpus_type + "_" + maybe_id
else:
shard_base = corpus_type
data_path = "{:s}.{:s}.{:d}.pt".\
format(opt.save_data, shard_base, i)
dataset_paths.append(data_path)
logger.info(" * saving %sth %s data shard to %s."
% (i, shard_base, data_path))
dataset.save(data_path)
del dataset.examples
gc.collect()
del dataset
gc.collect()
if corpus_type == "train":
vocab_path = opt.save_data + '.vocab.pt'
if existing_fields is None:
fields = _build_fields_vocab(
fields, counters, opt.data_type,
opt.share_vocab, opt.vocab_size_multiple,
opt.src_vocab_size, opt.src_words_min_frequency,
opt.tgt_vocab_size, opt.tgt_words_min_frequency)
else:
fields = existing_fields
torch.save(fields, vocab_path)
def build_save_vocab(train_dataset, fields, opt):
fields = inputters.build_vocab(
train_dataset, fields, opt.data_type, opt.share_vocab,
opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency,
opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency,
vocab_size_multiple=opt.vocab_size_multiple
)
vocab_path = opt.save_data + '.vocab.pt'
torch.save(fields, vocab_path)
def count_features(path):
"""
path: location of a corpus file with whitespace-delimited tokens and
-delimited features within the token
returns: the number of features in the dataset
"""
with codecs.open(path, "r", "utf-8") as f:
first_tok = f.readline().split(None, 1)[0]
return len(first_tok.split(u"")) - 1
def main(opt):
ArgumentParser.validate_preprocess_args(opt)
torch.manual_seed(opt.seed)
if not(opt.overwrite):
check_existing_pt_files(opt)
init_logger(opt.log_file)
logger.info("Extracting features...")
src_nfeats = 0
tgt_nfeats = 0
for src, tgt in zip(opt.train_src, opt.train_tgt):
src_nfeats += count_features(src) if opt.data_type == 'text' \
else 0
tgt_nfeats += count_features(tgt) # tgt always text so far
logger.info(" * number of source features: %d." % src_nfeats)
logger.info(" * number of target features: %d." % tgt_nfeats)
logger.info("Building `Fields` object...")
fields = inputters.get_fields(
opt.data_type,
src_nfeats,
tgt_nfeats,
dynamic_dict=opt.dynamic_dict,
src_truncate=opt.src_seq_length_trunc,
tgt_truncate=opt.tgt_seq_length_trunc)
src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
tgt_reader = inputters.str2reader["text"].from_opt(opt)
logger.info("Building & saving training data...")
build_save_dataset(
'train', fields, src_reader, tgt_reader, opt)
if opt.valid_src and opt.valid_tgt:
logger.info("Building & saving validation data...")
build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
def _get_parser():
parser = ArgumentParser(description='preprocess.py')
opts.config_opts(parser)
opts.preprocess_opts(parser)
return parser
 if __name__ == "__main__":
-    parser = _get_parser()
-    opt = parser.parse_args()
-    main(opt)
+    main()
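
The preprocessing logic deleted here now lives in `onmt.bin.preprocess`, with `preprocess.py` kept as a thin wrapper. A hedged, equivalent programmatic call, assuming `main()` reads `sys.argv` exactly as the `onmt_preprocess` console script does; all paths are placeholders:

```
import sys
from onmt.bin.preprocess import main

sys.argv = [
    "preprocess.py",
    "-train_src", "data/src-train.txt",
    "-train_tgt", "data/tgt-train.txt",
    "-valid_src", "data/src-val.txt",
    "-valid_tgt", "data/tgt-val.txt",
    "-save_data", "data/demo",
]
main()  # same as running `onmt_preprocess` with these flags
```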

View file

@@ -1,11 +1,10 @@
 cffi
-torchvision==0.2.1
+torchvision
 joblib
 librosa
 Pillow
 git+git://github.com/pytorch/audio.git@d92de5b97fc6204db4b1e3ed20c03ac06f5d53f0
 pyrouge
-pyonmttok
 opencv-python
 git+https://github.com/NVIDIA/apex
-flask
+pretrainedmodels

View file

@ -1,6 +0,0 @@
six
tqdm==4.30.*
torch>=1.1
git+https://github.com/pytorch/text.git@master#wheel=torchtext
future
configargparse

View file

@@ -1,129 +1,6 @@
 #!/usr/bin/env python
-import configargparse
+from onmt.bin.server import main
from flask import Flask, jsonify, request
from onmt.translate import TranslationServer, ServerModelError
STATUS_OK = "ok"
STATUS_ERROR = "error"
-def start(config_file,
-          url_root="./translator",
+if __name__ == "__main__":
+    main()
host="0.0.0.0",
port=5000,
debug=True):
def prefix_route(route_function, prefix='', mask='{0}{1}'):
def newroute(route, *args, **kwargs):
return route_function(mask.format(prefix, route), *args, **kwargs)
return newroute
app = Flask(__name__)
app.route = prefix_route(app.route, url_root)
translation_server = TranslationServer()
translation_server.start(config_file)
@app.route('/models', methods=['GET'])
def get_models():
out = translation_server.list_models()
return jsonify(out)
@app.route('/health', methods=['GET'])
def health():
out = {}
out['status'] = STATUS_OK
return jsonify(out)
@app.route('/clone_model/<int:model_id>', methods=['POST'])
def clone_model(model_id):
out = {}
data = request.get_json(force=True)
timeout = -1
if 'timeout' in data:
timeout = data['timeout']
del data['timeout']
opt = data.get('opt', None)
try:
model_id, load_time = translation_server.clone_model(
model_id, opt, timeout)
except ServerModelError as e:
out['status'] = STATUS_ERROR
out['error'] = str(e)
else:
out['status'] = STATUS_OK
out['model_id'] = model_id
out['load_time'] = load_time
return jsonify(out)
@app.route('/unload_model/<int:model_id>', methods=['GET'])
def unload_model(model_id):
out = {"model_id": model_id}
try:
translation_server.unload_model(model_id)
out['status'] = STATUS_OK
except Exception as e:
out['status'] = STATUS_ERROR
out['error'] = str(e)
return jsonify(out)
@app.route('/translate', methods=['POST'])
def translate():
inputs = request.get_json(force=True)
out = {}
try:
translation, scores, n_best, times = translation_server.run(inputs)
assert len(translation) == len(inputs)
assert len(scores) == len(inputs)
out = [[{"src": inputs[i]['src'], "tgt": translation[i],
"n_best": n_best,
"pred_score": scores[i]}
for i in range(len(translation))]]
except ServerModelError as e:
out['error'] = str(e)
out['status'] = STATUS_ERROR
return jsonify(out)
@app.route('/to_cpu/<int:model_id>', methods=['GET'])
def to_cpu(model_id):
out = {'model_id': model_id}
translation_server.models[model_id].to_cpu()
out['status'] = STATUS_OK
return jsonify(out)
@app.route('/to_gpu/<int:model_id>', methods=['GET'])
def to_gpu(model_id):
out = {'model_id': model_id}
translation_server.models[model_id].to_gpu()
out['status'] = STATUS_OK
return jsonify(out)
app.run(debug=debug, host=host, port=port, use_reloader=False,
threaded=True)
def _get_parser():
parser = configargparse.ArgumentParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
description="OpenNMT-py REST Server")
parser.add_argument("--ip", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default="5000")
parser.add_argument("--url_root", type=str, default="/translator")
parser.add_argument("--debug", "-d", action="store_true")
parser.add_argument("--config", "-c", type=str,
default="./available_models/conf.json")
return parser
if __name__ == '__main__':
parser = _get_parser()
args = parser.parse_args()
start(args.config, url_root=args.url_root, host=args.ip, port=args.port,
debug=args.debug)
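
The REST server moved to `onmt.bin.server` as well; the Flask routes listed above (`/models`, `/health`, `/translate`, ...) are assumed unchanged there. A hedged launch sketch, with the flags taken from the removed argument parser and a placeholder config path:

```
import sys
from onmt.bin.server import main

sys.argv = [
    "server.py",
    "--config", "./available_models/conf.json",
    "--ip", "0.0.0.0",
    "--port", "5000",
    "--url_root", "/translator",
]
main()  # serves the translation REST API described above
```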

View file

@@ -1,11 +1,45 @@
 #!/usr/bin/env python
-from setuptools import setup
-
-setup(name='OpenNMT-py',
-      description='A python implementation of OpenNMT',
-      version='0.9.1',
-      packages=['onmt', 'onmt.encoders', 'onmt.modules', 'onmt.tests',
-                'onmt.translate', 'onmt.decoders', 'onmt.inputters',
-                'onmt.models', 'onmt.utils'])
+from setuptools import setup, find_packages
+from os import path
+
+this_directory = path.abspath(path.dirname(__file__))
+with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
+    long_description = f.read()
+
+setup(
+    name='OpenNMT-py',
+    description='A python implementation of OpenNMT',
+    long_description=long_description,
+    long_description_content_type='text/markdown',
+    version='1.1.1',
+    packages=find_packages(),
+    project_urls={
+        "Documentation": "http://opennmt.net/OpenNMT-py/",
+        "Forum": "http://forum.opennmt.net/",
+        "Gitter": "https://gitter.im/OpenNMT/OpenNMT-py",
+        "Source": "https://github.com/OpenNMT/OpenNMT-py/"
+    },
+    install_requires=[
+        "six",
+        "tqdm~=4.30.0",
+        "torch>=1.4.0",
+        "torchtext==0.4.0",
+        "future",
+        "configargparse",
+        "tensorboard>=1.14",
+        "flask",
+        "waitress",
+        "pyonmttok==1.*;platform_system=='Linux'",
+        "pyyaml",
+    ],
+    entry_points={
+        "console_scripts": [
+            "onmt_server=onmt.bin.server:main",
+            "onmt_train=onmt.bin.train:main",
+            "onmt_translate=onmt.bin.translate:main",
+            "onmt_preprocess=onmt.bin.preprocess:main",
+            "onmt_release_model=onmt.bin.release_model:main",
+            "onmt_average_models=onmt.bin.average_models:main"
+        ],
+    }
+)
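
With these entry points installed (`python setup.py install`, as in the Dockerfile), the `onmt_*` commands simply resolve to the `onmt.bin` modules, so the same functionality is importable directly:

```
# Each console script above maps to a plain Python callable:
from onmt.bin.preprocess import main as onmt_preprocess
from onmt.bin.train import main as onmt_train
from onmt.bin.translate import main as onmt_translate
from onmt.bin.server import main as onmt_server
```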

0
opennmt/tools/apply_bpe.py Normal file → Executable file
View file

42
opennmt/tools/average_models.py Normal file → Executable file
View file

@@ -1,45 +1,5 @@
 #!/usr/bin/env python
-import argparse
+from onmt.bin.average_models import main
import torch
def average_models(model_files):
vocab = None
opt = None
avg_model = None
avg_generator = None
for i, model_file in enumerate(model_files):
m = torch.load(model_file, map_location='cpu')
model_weights = m['model']
generator_weights = m['generator']
if i == 0:
vocab, opt = m['vocab'], m['opt']
avg_model = model_weights
avg_generator = generator_weights
else:
for (k, v) in avg_model.items():
avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
for (k, v) in avg_generator.items():
avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
final = {"vocab": vocab, "opt": opt, "optim": None,
"generator": avg_generator, "model": avg_model}
return final
def main():
parser = argparse.ArgumentParser(description="")
parser.add_argument("-models", "-m", nargs="+", required=True,
help="List of models")
parser.add_argument("-output", "-o", required=True,
help="Output file")
opt = parser.parse_args()
final = average_models(opt.models)
torch.save(final, opt.output)
 if __name__ == "__main__":

0
opennmt/tools/bpe_pipeline.sh Normal file → Executable file
View file

View file

@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 import argparse
 import sys
+import os

 def read_files_batch(file_list):
@@ -10,7 +11,6 @@ def read_files_batch(file_list):
     fd_list = []  # File descriptor list
     exit = False  # Flag used for quitting the program in case of error
     try:
         for filename in file_list:
             fd_list.append(open(filename))
@@ -51,7 +51,8 @@ def main():
                         corresponding to the argument 'side'.""")
     parser.add_argument("-file", type=str, nargs="+", required=True)
     parser.add_argument("-out_file", type=str, required=True)
-    parser.add_argument("-side", type=str)
+    parser.add_argument("-side", choices=['src', 'tgt'], help="""Specifies
+                        'src' or 'tgt' side for 'field' file_type.""")

     opt = parser.parse_args()
@@ -72,8 +73,16 @@ def main():
                             reverse=True):
                 f.write("{0}\n".format(w))
     else:
+        if opt.side not in ['src', 'tgt']:
+            raise ValueError("If using -file_type='field', specifies "
+                             "'src' or 'tgt' argument for -side.")
         import torch
+        try:
             from onmt.inputters.inputter import _old_style_vocab
+        except ImportError:
+            sys.path.insert(1, os.path.join(sys.path[0], '..'))
+            from onmt.inputters.inputter import _old_style_vocab
         print("Reading input file...")
         if not len(opt.file) == 1:
             raise ValueError("If using -file_type='field', only pass one "

0
opennmt/tools/detokenize.perl Normal file → Executable file
View file

2
opennmt/tools/embeddings_to_torch.py Normal file → Executable file
View file

@@ -137,7 +137,7 @@ def main():
     logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
                 % calc_vocab_load_stats(enc_vocab, src_vectors))
     logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
-                % calc_vocab_load_stats(dec_vocab, src_vectors))
+                % calc_vocab_load_stats(dec_vocab, tgt_vectors))

     # Write to file
     enc_output_file = opt.output_file + ".enc.pt"

0
opennmt/tools/learn_bpe.py Normal file → Executable file
View file

0
opennmt/tools/multi-bleu-detok.perl Normal file → Executable file
View file

View file

@@ -1,16 +1,6 @@
 #!/usr/bin/env python
-import argparse
-import torch
+from onmt.bin.release_model import main
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
+    main()
description="Removes the optim data of PyTorch models")
parser.add_argument("--model", "-m",
help="The model filename (*.pt)", required=True)
parser.add_argument("--output", "-o",
help="The output filename (*.pt)", required=True)
opt = parser.parse_args()
model = torch.load(opt.model)
model['optim'] = None
torch.save(model, opt.output)

0
opennmt/tools/tokenizer.perl Normal file → Executable file
View file

View file

@@ -1,200 +1,6 @@
 #!/usr/bin/env python
-"""Train models."""
+from onmt.bin.train import main
import os
import signal
import torch
import onmt.opts as opts
import onmt.utils.distributed
from onmt.utils.misc import set_random_seed
from onmt.utils.logging import init_logger, logger
from onmt.train_single import main as single_main
from onmt.utils.parse import ArgumentParser
from onmt.inputters.inputter import build_dataset_iter, \
load_old_vocab, old_style_vocab, build_dataset_iter_multiple
from itertools import cycle
def main(opt):
ArgumentParser.validate_train_opts(opt)
ArgumentParser.update_model_opts(opt)
ArgumentParser.validate_model_opts(opt)
# Load checkpoint if we resume from a previous training.
if opt.train_from:
logger.info('Loading checkpoint from %s' % opt.train_from)
checkpoint = torch.load(opt.train_from,
map_location=lambda storage, loc: storage)
logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
vocab = checkpoint['vocab']
else:
vocab = torch.load(opt.data + '.vocab.pt')
# check for code where vocab is saved instead of fields
# (in the future this will be done in a smarter way)
if old_style_vocab(vocab):
fields = load_old_vocab(
vocab, opt.model_type, dynamic_dict=opt.copy_attn)
else:
fields = vocab
if len(opt.data_ids) > 1:
train_shards = []
for train_id in opt.data_ids:
shard_base = "train_" + train_id
train_shards.append(shard_base)
train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
else:
if opt.data_ids[0] is not None:
shard_base = "train_" + opt.data_ids[0]
else:
shard_base = "train"
train_iter = build_dataset_iter(shard_base, fields, opt)
nb_gpu = len(opt.gpu_ranks)
if opt.world_size > 1:
queues = []
mp = torch.multiprocessing.get_context('spawn')
semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
for device_id in range(nb_gpu):
q = mp.Queue(opt.queue_size)
queues += [q]
procs.append(mp.Process(target=run, args=(
opt, device_id, error_queue, q, semaphore), daemon=True))
procs[device_id].start()
logger.info(" Starting process pid: %d " % procs[device_id].pid)
error_handler.add_child(procs[device_id].pid)
producer = mp.Process(target=batch_producer,
args=(train_iter, queues, semaphore, opt,),
daemon=True)
producer.start()
error_handler.add_child(producer.pid)
for p in procs:
p.join()
producer.terminate()
elif nb_gpu == 1: # case 1 GPU only
single_main(opt, 0)
else: # case only CPU
single_main(opt, -1)
def batch_producer(generator_to_serve, queues, semaphore, opt):
init_logger(opt.log_file)
set_random_seed(opt.seed, False)
# generator_to_serve = iter(generator_to_serve)
def pred(x):
"""
Filters batches that belong only
to gpu_ranks of current node
"""
for rank in opt.gpu_ranks:
if x[0] % opt.world_size == rank:
return True
generator_to_serve = filter(
pred, enumerate(generator_to_serve))
def next_batch(device_id):
new_batch = next(generator_to_serve)
semaphore.acquire()
return new_batch[1]
b = next_batch(0)
for device_id, q in cycle(enumerate(queues)):
b.dataset = None
if isinstance(b.src, tuple):
b.src = tuple([_.to(torch.device(device_id))
for _ in b.src])
else:
b.src = b.src.to(torch.device(device_id))
b.tgt = b.tgt.to(torch.device(device_id))
b.indices = b.indices.to(torch.device(device_id))
b.alignment = b.alignment.to(torch.device(device_id)) \
if hasattr(b, 'alignment') else None
b.src_map = b.src_map.to(torch.device(device_id)) \
if hasattr(b, 'src_map') else None
# hack to dodge unpicklable `dict_keys`
b.fields = list(b.fields)
q.put(b)
b = next_batch(device_id)
def run(opt, device_id, error_queue, batch_queue, semaphore):
""" run process """
try:
gpu_rank = onmt.utils.distributed.multi_init(opt, device_id)
if gpu_rank != opt.gpu_ranks[device_id]:
raise AssertionError("An error occurred in \
Distributed initialization")
single_main(opt, device_id, batch_queue, semaphore)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""
def __init__(self, error_queue):
""" init error handler """
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(
target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
""" error handler """
self.children_pids.append(pid)
def error_listener(self):
""" error listener """
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, signalnum, stackframe):
""" signal handler """
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
msg = """\n\n-- Tracebacks above this line can probably
be ignored --\n\n"""
msg += original_trace
raise Exception(msg)
def _get_parser():
parser = ArgumentParser(description='train.py')
opts.config_opts(parser)
opts.model_opts(parser)
opts.train_opts(parser)
return parser
 if __name__ == "__main__":
-    parser = _get_parser()
-    opt = parser.parse_args()
-    main(opt)
+    main()
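
Multi-GPU spawning, the batch producer, and the error handler deleted here all moved to `onmt.bin.train`. A hedged single-GPU launch, assuming `main()` reads `sys.argv` like the `onmt_train` console script; paths and step counts are placeholders:

```
import sys
from onmt.bin.train import main

sys.argv = [
    "train.py",
    "-data", "data/demo",          # prefix produced by onmt_preprocess
    "-save_model", "model/demo",
    "-train_steps", "1000",
    "-world_size", "1", "-gpu_ranks", "0",  # drop these two flags to train on CPU
]
main()
```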

View file

@@ -1,52 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
+from onmt.bin.translate import main
from __future__ import unicode_literals
from itertools import repeat
from onmt.utils.logging import init_logger
from onmt.utils.misc import split_corpus
from onmt.translate.translator import build_translator
import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
import pickle
def main(opt):
with open('../opt_data', 'wb') as f:
pickle.dump(opt, f)
ArgumentParser.validate_translate_opts(opt)
logger = init_logger(opt.log_file)
translator = build_translator(opt, report_score=True)
src_shards = split_corpus(opt.src, opt.shard_size)
tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
if opt.tgt is not None else repeat(None)
shard_pairs = zip(src_shards, tgt_shards)
for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
logger.info("Translating shard %d." % i)
translator.translate(
src=src_shard,
tgt=tgt_shard,
src_dir=opt.src_dir,
batch_size=opt.batch_size,
attn_debug=opt.attn_debug
)
def _get_parser():
parser = ArgumentParser(description='translate.py')
opts.config_opts(parser)
opts.translate_opts(parser)
return parser
 if __name__ == "__main__":
-    parser = _get_parser()
-    opt = parser.parse_args()
-    main(opt)
+    main()
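
Translation likewise delegates to `onmt.bin.translate`; note that this fork's old wrapper also pickled `opt` to `../opt_data` (compare `prepare_opts.py` above). A hedged invocation with placeholder paths, assuming `main()` reads `sys.argv` like the `onmt_translate` console script:

```
import sys
from onmt.bin.translate import main

sys.argv = [
    "translate.py",
    "-model", "model/demo_step_1000.pt",
    "-src", "data/src-test.txt",
    "-output", "pred.txt",
    "-replace_unk", "-verbose",
]
main()
```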

Binary data
opt_data

Binary file not shown.

View file

@@ -1,12 +1,6 @@
 Django==3.0.3
-numpy==1.16.3
+numpy
-six
-tqdm==4.30.*
-torch==1.1 -f https://download.pytorch.org/whl/torch_stable.html
-git+https://github.com/pytorch/text.git@master#wheel=torchtext
-future
 channels
-configargparse
 editdistance
 indic_transliteration
 jsonfield