зеркало из https://github.com/microsoft/fastseq.git
Fairseq v0.10.2 compatible (#104)
Updated Fastseq to be compatible with the latest version of Fairseq (0.10.2).
This commit is contained in:
Родитель
c70b715dec
Коммит
1974223378
|
@ -16,17 +16,17 @@ Below shows the generation speed gain by using FastSeq.
|
|||
|
||||
| Model | W/O FastSeq (in samples/s) | W/ FastSeq (in samples/s) | Speedup |
|
||||
|------------------|:--------------------------:|:-------------------------:|:-----:|
|
||||
| [ProphetNet](examples/prophetnet/README.md) | 2.8 | 10.7 | 3.8x |
|
||||
| [Bart (`fs`)](examples/bart/README.md) | 2.4 | 25.3 | 10.5x |
|
||||
| [ProphetNet](examples/prophetnet/README.md) | 2.8 | 11.9 | 4.3 |
|
||||
| [Bart (`fs`)](examples/bart/README.md) | 3.3 | 25.1 | 7.7x |
|
||||
| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 2.5 | 12.4 | 5.0x |
|
||||
| [DistilBart (`hf`)](examples/distilbart/README.md) | 3.4 | 18.5 | 5.4x |
|
||||
| [T5 (`hf`)](examples/t5/README.md) | 8.7 | 31.3 | 3.6x |
|
||||
| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 96.0 | 417.0 | 4.3x |
|
||||
| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 144.5 | 422.8 | 2.9x |
|
||||
| [GPT2 (`hf`)](examples/gpt2/README.md) | 3.0 | 16.7 | 5.5x |
|
||||
| [UniLM (`hf`)](examples/unilm/README.md) | 1.7 | 16.4 | 9.6x |
|
||||
|
||||
- All benchmarking experiments run on NVIDIA-V100-16GB with [docker](docker/Dockerfile). Highest speed recorded for each model by tuning batch size. For parameter setting details, click link of corresponding model.
|
||||
- `fs` stands for [Fairseq](https://github.com/pytorch/fairseq) 0.9.0 version, `hf` stands for [Huggingface Transformers](https://github.com/huggingface/transformers) 3.0.2 version.
|
||||
- `fs` stands for [Fairseq](https://github.com/pytorch/fairseq) 0.10.2 version, `hf` stands for [Huggingface Transformers](https://github.com/huggingface/transformers) 3.0.2 version.
|
||||
- Optimizations were automatically applied to all generation/sequence models in Fairseq & Huggingface Transformers. Above only lists a subset of them.
|
||||
|
||||
## How it works?
|
||||
|
|
|
@ -15,13 +15,29 @@ jobs:
|
|||
demands:
|
||||
- agent.name -equals gpu3
|
||||
container:
|
||||
image: adsbrainwestus2.azurecr.io/fastseq:dev-py3
|
||||
image: adsbrainwestus2.azurecr.io/fastseq:dev-py3
|
||||
endpoint: fastseq-acr
|
||||
options: --gpus device=3
|
||||
steps:
|
||||
- script: |
|
||||
#install fastseq
|
||||
pip install --editable .[transformers,fairseq]
|
||||
which pip
|
||||
which python
|
||||
|
||||
echo "******* Installing fairseq *******"
|
||||
pip install fairseq==0.10.2
|
||||
pip show fairseq
|
||||
|
||||
echo "******* Installing transformers *******"
|
||||
pip install transformers
|
||||
pip show transformers
|
||||
|
||||
echo "******* Installing fastseq *******"
|
||||
pip install --editable .
|
||||
pip show fastseq
|
||||
|
||||
echo "******* Adding local bin to path *******"
|
||||
export PATH="$HOME/bin:$HOME/.local/bin:$PATH"
|
||||
|
||||
echo "******* Running fastseq unittests *******"
|
||||
pip install pytorch-transformers==1.0.0
|
||||
|
|
|
@ -112,6 +112,7 @@ for bs in "${bs_list[@]}"; do
|
|||
--no-repeat-ngram-size 3 \
|
||||
--lenpen 2.0 \
|
||||
--use-el-attn \
|
||||
--required-seq-len-multiple 8 \
|
||||
`#--print-alignment` \
|
||||
`#--print-step # KeyError: steps` \
|
||||
--skip-invalid-size-inputs-valid-test $* \
|
||||
|
@ -132,6 +133,7 @@ for bs in "${bs_list[@]}"; do
|
|||
--max-len-b 140 \
|
||||
--no-repeat-ngram-size 3 \
|
||||
--lenpen 2.0 \
|
||||
--required-seq-len-multiple 8 \
|
||||
`#--print-alignment` \
|
||||
`#--print-step # KeyError: steps` \
|
||||
--skip-invalid-size-inputs-valid-test $* \
|
||||
|
@ -140,7 +142,7 @@ for bs in "${bs_list[@]}"; do
|
|||
ret=$?
|
||||
end=`date +%s`
|
||||
runtime=$(($end-$start))
|
||||
tail=`tail -2 $STDOUT_FILE`
|
||||
tail=`tail -3 $STDOUT_FILE`
|
||||
if [[ $ret -eq 0 && $tail == *$mark1* ]]; then
|
||||
samples=`echo $tail | sed 's/.*Translated \([0-9]*\) sentences.*/\1/'`
|
||||
tokens=`echo $tail | sed 's/.*Translated .* sentences (\([0-9]*\) tokens).*/\1/'`
|
||||
|
|
|
@ -34,21 +34,21 @@ grep "bart.large.cnn cnn_dm/len-1024.bin valid " perf \
|
|||
| awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' \
|
||||
| ./range.sh 17.9 18
|
||||
# Speed on V100 16GB 250W
|
||||
grep -E "fairseq_v0.9.0 bart.large.cnn cnn_dm/len-1024.bin valid 32 " perf \
|
||||
grep -E "fairseq_v0.10.2 bart.large.cnn cnn_dm/len-1024.bin valid 32 " perf \
|
||||
| awk '{s+=$13}END{if(NR==0) print -1; else print s/NR}' \
|
||||
| ./range.sh 2.1 2.7
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 32 " perf \
|
||||
| ./range.sh 3.1 3.7
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 32 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 7.8 100
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 64 " perf \
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 64 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 13.0 100
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 128 " perf \
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 128 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 18.1 100
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 256 " perf \
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 256 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 19 100
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 320 " perf \
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* bart.large.cnn cnn_dm/len-1024.bin valid 320 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 25 100
|
||||
| ./range.sh 24.5 100
|
||||
|
|
|
@ -8,19 +8,20 @@
|
|||
# <batch-sizes>
|
||||
source utils.sh
|
||||
|
||||
# TODO: update when ProphetNet is compatible with fairseq 0.10.2
|
||||
# Download ProphetNet repo as the baseline if it does not exist
|
||||
prophetnet_repo_path=$CACHE_DIR/ProphetNet
|
||||
git_clone_if_not_in_cache \
|
||||
https://github.com/microsoft/ProphetNet.git \
|
||||
$prophetnet_repo_path
|
||||
|
||||
./benchmark.sh \
|
||||
fairseq \
|
||||
prophetnet_large_160G_cnndm_model \
|
||||
cnn_dm_bert/len-512.bin \
|
||||
valid \
|
||||
64 \
|
||||
--user-dir $prophetnet_repo_path/src/prophetnet/
|
||||
# prophetnet_repo_path=$CACHE_DIR/ProphetNet
|
||||
# git_clone_if_not_in_cache \
|
||||
# https://github.com/microsoft/ProphetNet.git \
|
||||
# $prophetnet_repo_path
|
||||
#
|
||||
# ./benchmark.sh \
|
||||
# fairseq \
|
||||
# prophetnet_large_160G_cnndm_model \
|
||||
# cnn_dm_bert/len-512.bin \
|
||||
# valid \
|
||||
# 64 \
|
||||
# --user-dir $prophetnet_repo_path/src/prophetnet/
|
||||
./benchmark.sh \
|
||||
fairseq+fastseq \
|
||||
prophetnet_large_160G_cnndm_model \
|
||||
|
@ -33,18 +34,19 @@ grep "prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid" perf \
|
|||
| awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' \
|
||||
| ./range.sh 19.1 19.2
|
||||
# # Speed on V100 16GB 250W
|
||||
grep -E "fairseq_v0.9.0 prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 32 " perf \
|
||||
| awk '{s+=$13}END{if(NR==0) print -1; else print s/NR}' \
|
||||
| ./range.sh 2 3
|
||||
grep -E "fairseq_v0.9.0 prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 64 " perf \
|
||||
| awk '{s+=$13}END{if(NR==0) print -1; else print s/NR}' \
|
||||
| ./range.sh 2 3
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 32 " perf \
|
||||
# TODO: update when ProphetNet is compatible with fairseq 0.10.2
|
||||
# grep -E "fairseq_v0.10.2 prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 32 " perf \
|
||||
# | awk '{s+=$13}END{if(NR==0) print -1; else print s/NR}' \
|
||||
# | ./range.sh 2 3
|
||||
# grep -E "fairseq_v0.10.2 prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 64 " perf \
|
||||
# | awk '{s+=$13}END{if(NR==0) print -1; else print s/NR}' \
|
||||
# | ./range.sh 2 3
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 32 " perf \
|
||||
| awk '{s+=$13}END{if(NR==0) print -1; else print s/NR}' \
|
||||
| ./range.sh 5.7 6.5
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 64 " perf \
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 64 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 7.5 10
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 128 " perf \
|
||||
| ./range.sh 8 10.5
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* prophetnet_large_160G_cnndm_model cnn_dm_bert/len-512.bin valid 128 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}'\
|
||||
| ./range.sh 10 15
|
||||
|
|
|
@ -27,15 +27,15 @@ grep " wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid " perf \
|
|||
| awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' \
|
||||
| ./range.sh 0.05 0.07
|
||||
# Speed on V100 16GB 250W
|
||||
grep -E "fairseq_v0.9.0 wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 256 " perf \
|
||||
grep -E "fairseq_v0.10.2 wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 256 " perf \
|
||||
| awk '{s+=$13}END{if(NR==0) print -1; else print s/NR}' \
|
||||
| ./range.sh 93 100
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 256 " perf \
|
||||
| ./range.sh 100 150
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 256 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 350 1000
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 512 " perf \
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 512 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 390 1000
|
||||
grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 1024 " perf \
|
||||
grep -E "fairseq_v0.10.2\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin valid 1024 " perf \
|
||||
| awk '{s+=$13}END{print s/NR}' \
|
||||
| ./range.sh 405 1000
|
||||
|
|
|
@ -49,7 +49,7 @@ RUN pip install --upgrade pip && \
|
|||
pip install requests>=v2.24.0 && \
|
||||
pip install gitpython>=v3.1.7 && \
|
||||
pip install rouge_score==v0.0.4 && \
|
||||
pip install fairseq==v0.9.0 && \
|
||||
pip install fairseq==v0.10.2 && \
|
||||
pip install transformers==v3.0.2 && \
|
||||
pip install pytorch-transformers==1.0.0
|
||||
|
||||
|
|
|
@ -1,38 +1,18 @@
|
|||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"name": "EL-attention_Demo.ipynb",
|
||||
"provenance": [],
|
||||
"collapsed_sections": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"id": "9xZnWgQ1o7yW",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "9xZnWgQ1o7yW",
|
||||
"outputId": "fb630728-01e2-4cbc-c509-4874c97d8d94"
|
||||
},
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
],
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Tue May 25 01:05:18 2021 \n",
|
||||
|
@ -55,9 +35,11 @@
|
|||
"|=============================================================================|\n",
|
||||
"| No running processes found |\n",
|
||||
"+-----------------------------------------------------------------------------+\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -71,6 +53,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -78,17 +61,9 @@
|
|||
"id": "ipS-Ptqupa1X",
|
||||
"outputId": "d7cd49e3-4183-40b2-fcbc-35706a606c93"
|
||||
},
|
||||
"source": [
|
||||
"!pip install fairseq==v0.9.0\n",
|
||||
"!pip install git+https://github.com/microsoft/fastseq.git\n",
|
||||
"\n",
|
||||
"#!git clone https://github.com/NVIDIA/apex\n",
|
||||
"#!sed -i 's/if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):/if False:/' ./apex/setup.py\n",
|
||||
"#!pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./apex"
|
||||
],
|
||||
"execution_count": 2,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Collecting fairseq==v0.9.0\n",
|
||||
|
@ -161,13 +136,21 @@
|
|||
"\u001b[31mERROR: botocore 1.20.79 has requirement urllib3<1.27,>=1.25.4, but you'll have urllib3 1.24.3 which is incompatible.\u001b[0m\n",
|
||||
"Installing collected packages: rouge-score, sentencepiece, jmespath, botocore, s3transfer, boto3, pytorch-transformers, fastseq\n",
|
||||
"Successfully installed boto3-1.17.79 botocore-1.20.79 fastseq-0.0.4 jmespath-0.10.0 pytorch-transformers-1.0.0 rouge-score-0.0.4 s3transfer-0.4.2 sentencepiece-0.1.95\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!pip install fairseq==v0.10.2\n",
|
||||
"!pip install git+https://github.com/microsoft/fastseq.git\n",
|
||||
"\n",
|
||||
"#!git clone https://github.com/NVIDIA/apex\n",
|
||||
"#!sed -i 's/if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):/if False:/' ./apex/setup.py\n",
|
||||
"#!pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./apex"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -175,23 +158,9 @@
|
|||
"id": "nsnnkjHV4VHq",
|
||||
"outputId": "751fe3ec-cf88-4861-8734-3c5b358d34b4"
|
||||
},
|
||||
"source": [
|
||||
"!mkdir -p data\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/for_bart/test.source -P data/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/for_bart/test.target -P data/\n",
|
||||
"\n",
|
||||
"!mkdir -p data/bin\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.source.bin -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.source.idx -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.target.bin -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.target.idx -P data/bin/\n",
|
||||
"\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/dict.source.txt -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/dict.target.txt -P data/bin/\n"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"--2021-05-25 01:07:00-- https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/for_bart/test.source\n",
|
||||
|
@ -282,27 +251,37 @@
|
|||
"\n",
|
||||
"2021-05-25 01:07:09 (1.53 MB/s) - ‘data/bin/dict.target.txt’ saved [603290/603290]\n",
|
||||
"\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!mkdir -p data\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/for_bart/test.source -P data/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/for_bart/test.target -P data/\n",
|
||||
"\n",
|
||||
"!mkdir -p data/bin\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.source.bin -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.source.idx -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.target.bin -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/test.source-target.target.idx -P data/bin/\n",
|
||||
"\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/dict.source.txt -P data/bin/\n",
|
||||
"!wget https://fastseq.blob.core.windows.net/data/tasks/cnn_dm/len-1024.bin/dict.target.txt -P data/bin/\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"id": "KJ40BxorphD2",
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "KJ40BxorphD2",
|
||||
"outputId": "6720f603-f167-4e9c-9bc6-7b709105939e"
|
||||
},
|
||||
"source": [
|
||||
"!wget 'https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz'\n",
|
||||
"!tar xvzf 'bart.large.cnn.tar.gz'"
|
||||
],
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"--2021-05-25 01:07:09-- https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz\n",
|
||||
|
@ -321,9 +300,12 @@
|
|||
"bart.large.cnn/model.pt\n",
|
||||
"bart.large.cnn/dict.source.txt\n",
|
||||
"bart.large.cnn/dict.target.txt\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!wget 'https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz'\n",
|
||||
"!tar xvzf 'bart.large.cnn.tar.gz'"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -346,6 +328,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -353,14 +336,9 @@
|
|||
"id": "hZs68SvhVD7p",
|
||||
"outputId": "65f1af03-6806-46bf-dec1-71c79095c002"
|
||||
},
|
||||
"source": [
|
||||
"!time fastseq-generate-for-fairseq data/bin/ --path bart.large.cnn/model.pt --use-el-attn --fp16 --task translation --batch-size 256 --gen-subset test --truncate-source --bpe gpt2 --beam 4 --min-len 55 --max-len-b 140 --no-repeat-ngram-size 3 --lenpen 2.0 --skip-invalid-size-inputs-valid-test > bart.el-attention_256.txt\n",
|
||||
"# uncomment above line to run\n",
|
||||
"# Expected time: (real\t24m23.314s for inference on T5 in colab)"
|
||||
],
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING 2021-05-25 01:11:15,116 /usr/local/lib/python3.7/dist-packages/fastseq/models/__init__.py:17] transformers can not be imported.\n",
|
||||
|
@ -375,9 +353,13 @@
|
|||
"real\t25m19.342s\n",
|
||||
"user\t29m29.255s\n",
|
||||
"sys\t0m26.392s\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!time fastseq-generate-for-fairseq data/bin/ --path bart.large.cnn/model.pt --use-el-attn --fp16 --task translation --batch-size 256 --gen-subset test --truncate-source --bpe gpt2 --beam 4 --min-len 55 --max-len-b 140 --no-repeat-ngram-size 3 --lenpen 2.0 --skip-invalid-size-inputs-valid-test > bart.el-attention_256.txt\n",
|
||||
"# uncomment above line to run\n",
|
||||
"# Expected time: (real\t24m23.314s for inference on T5 in colab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -391,28 +373,27 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"id": "LVHnPI7CVNdZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# !time fairseq-generate data/bin/ --path bart.large.cnn/model.pt --fp16 --task translation --batch-size 32 --gen-subset test --truncate-source --bpe gpt2 --beam 4 --min-len 55 --max-len-b 140 --no-repeat-ngram-size 3 --lenpen 2.0 --skip-invalid-size-inputs-valid-test > bart.multihead-attention_32.txt\n",
|
||||
"# uncomment above line to run\n",
|
||||
"# fairseq 0.9.0 is compariable with Pytorch 1.6, but in colab its version is 1.8.1"
|
||||
],
|
||||
"execution_count": 6,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "sCfbk3wuhtYO"
|
||||
},
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -420,23 +401,23 @@
|
|||
"id": "xuXz8Zzwhs2u",
|
||||
"outputId": "2e9a521f-eae5-487c-cc44-3a76ca8d2c4f"
|
||||
},
|
||||
"source": [
|
||||
"!ls "
|
||||
],
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"bart.el-attention_256.txt bart.large.cnn.tar.gz sample_data\n",
|
||||
"bart.large.cnn\t\t data\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!ls "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -444,25 +425,24 @@
|
|||
"id": "8xpZw27qjkoe",
|
||||
"outputId": "8182ebeb-6e08-4be2-a846-fb9bde932152"
|
||||
},
|
||||
"source": [
|
||||
"!echo 'Inference speed when using el attention'\n",
|
||||
"!tail -n 2 bart.el-attention_256.txt\n",
|
||||
"\n",
|
||||
"!echo 'Inference speed when using multihead attention'\n",
|
||||
"#!tail -n 2 bart.multihead-attention_32.txt"
|
||||
],
|
||||
"execution_count": 8,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Inference speed when using el attention\n",
|
||||
"| Translated 11490 sentences (944820 tokens) in 1472.9s (7.80 sentences/s, 641.48 tokens/s)\n",
|
||||
"| Generate test with beam=4: BLEU4 = 17.18, 36.9/19.4/13.1/9.3 (BP=1.000, ratio=1.203, syslen=933330, reflen=776133)\n",
|
||||
"Inference speed when using multihead attention\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!echo 'Inference speed when using el attention'\n",
|
||||
"!tail -n 2 bart.el-attention_256.txt\n",
|
||||
"\n",
|
||||
"!echo 'Inference speed when using multihead attention'\n",
|
||||
"#!tail -n 2 bart.multihead-attention_32.txt"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -496,6 +476,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -503,12 +484,9 @@
|
|||
"id": "idZceGNOz5tF",
|
||||
"outputId": "081e2af7-4957-42cf-9660-3679de16a832"
|
||||
},
|
||||
"source": [
|
||||
"!wget https://raw.githubusercontent.com/microsoft/fastseq/EL-attention-doc/examples/EL-attention/summarize.py"
|
||||
],
|
||||
"execution_count": 9,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"--2021-05-25 01:36:33-- https://raw.githubusercontent.com/microsoft/fastseq/EL-attention-doc/examples/EL-attention/summarize.py\n",
|
||||
|
@ -522,13 +500,16 @@
|
|||
"\n",
|
||||
"2021-05-25 01:36:33 (63.8 MB/s) - ‘summarize.py’ saved [3840/3840]\n",
|
||||
"\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!wget https://raw.githubusercontent.com/microsoft/fastseq/EL-attention-doc/examples/EL-attention/summarize.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -536,21 +517,9 @@
|
|||
"id": "FsuvH58auvsk",
|
||||
"outputId": "049fa56a-8921-4df6-ca84-8f4c3d7184c6"
|
||||
},
|
||||
"source": [
|
||||
"!time python summarize.py \\\n",
|
||||
" --model-dir bart.large.cnn/ \\\n",
|
||||
" --model-file model.pt \\\n",
|
||||
" --src data/test.source \\\n",
|
||||
" --bsz 320 \\\n",
|
||||
" --out 320_test.hypo \\\n",
|
||||
" --use-el-attn \\\n",
|
||||
" --n 3200\n",
|
||||
"\n",
|
||||
" # Expected time: (real\t7m36.176s for inference 3200 samples on T5 in colab)"
|
||||
],
|
||||
"execution_count": 10,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"WARNING 2021-05-25 01:36:34,967 /usr/local/lib/python3.7/dist-packages/fastseq/models/__init__.py:17] transformers can not be imported.\n",
|
||||
|
@ -579,13 +548,25 @@
|
|||
"real\t9m49.515s\n",
|
||||
"user\t7m50.919s\n",
|
||||
"sys\t0m8.170s\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!time python summarize.py \\\n",
|
||||
" --model-dir bart.large.cnn/ \\\n",
|
||||
" --model-file model.pt \\\n",
|
||||
" --src data/test.source \\\n",
|
||||
" --bsz 320 \\\n",
|
||||
" --out 320_test.hypo \\\n",
|
||||
" --use-el-attn \\\n",
|
||||
" --n 3200\n",
|
||||
"\n",
|
||||
" # Expected time: (real\t7m36.176s for inference 3200 samples on T5 in colab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -593,12 +574,9 @@
|
|||
"id": "8wWy6E7410Zt",
|
||||
"outputId": "3bcd0637-f5bf-466f-ebe0-f926e1a2e38f"
|
||||
},
|
||||
"source": [
|
||||
"!head 320_test.hypo"
|
||||
],
|
||||
"execution_count": 11,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"A French prosecutor says he is not aware of any video footage from on board the plane. German daily Bild and Paris Match claim to have found a cell phone video of the crash. A French Gendarmerie spokesman calls the reports \"completely wrong\" and \"unwarranted\" German airline Lufthansa says co-pilot Andreas Lubitz battled depression years before he took controls.\n",
|
||||
|
@ -611,22 +589,24 @@
|
|||
"Theia, a one-year-old bully breed mix, was hit by a car and buried in a field. She managed to stagger to a nearby farm, dirt-covered and emaciated. She suffered a dislocated jaw, leg injuries and a caved-in sinus cavity. A fundraising page has raised more than $10,000 for her care.\n",
|
||||
"Mohammad Javad Zarif is the Iranian foreign minister. He has been John Kerry's opposite number in securing a breakthrough in nuclear talks. He received a hero's welcome as he arrived in Iran on a sunny Friday morning. But there are some facts about Zarif that are less well-known.\n",
|
||||
"Bob Barker returned to \"The Price Is Right\" for the first time in eight years. The 91-year-old hosted the show for 35 years before stepping down in 2007. He handled the first price-guessing game, the classic \"Lucky Seven,\" before turning hosting duties over to Drew Carey.\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!head 320_test.hypo"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"id": "nmfO-Pj-7sYa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!cp 320_test.hypo test.hypo\n",
|
||||
"!cp data/test.target test.target"
|
||||
],
|
||||
"execution_count": 12,
|
||||
"outputs": []
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
@ -648,6 +628,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -655,6 +636,19 @@
|
|||
"id": "D4YkacBQyyke",
|
||||
"outputId": "81b3db13-85e7-4d7b-9aaa-54fbd09e2ded"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Error: Could not find or load main class edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"Caused by: java.lang.ClassNotFoundException: edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"Error: Could not find or load main class edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"Caused by: java.lang.ClassNotFoundException: edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"/bin/bash: files2rouge: command not found\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar\n",
|
||||
"\n",
|
||||
|
@ -663,32 +657,34 @@
|
|||
"!cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target\n",
|
||||
"!files2rouge test.hypo.tokenized test.hypo.target\n",
|
||||
"# Expected output: (ROUGE-2 Average_F: 0.21227)"
|
||||
],
|
||||
"execution_count": 13,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Error: Could not find or load main class edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"Caused by: java.lang.ClassNotFoundException: edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"Error: Could not find or load main class edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"Caused by: java.lang.ClassNotFoundException: edu.stanford.nlp.process.PTBTokenizer\n",
|
||||
"/bin/bash: files2rouge: command not found\n"
|
||||
],
|
||||
"name": "stdout"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"id": "5TfX3oz97yC_"
|
||||
},
|
||||
"source": [
|
||||
""
|
||||
],
|
||||
"execution_count": 13,
|
||||
"outputs": []
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "EL-attention_Demo.ipynb",
|
||||
"provenance": [],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@ BART is sequence-to-sequence model trained with denoising as pretraining objecti
|
|||
|
||||
| BatchSize | 32 | 64 | 128 | 320 |
|
||||
|:----------------:|:-------------:|:---------------:|:--------------:|:--------------:|
|
||||
| fairseq-0.9.0 | 2.4 samples/s | OOM | OOM | OOM |
|
||||
| above + fastseq | 8.1 samples/s | 13.3 samples/s | 18.4 samples/s | 25.3 samples/s |
|
||||
| fairseq-0.10.2 | 3.3 samples/s | OOM | OOM | OOM |
|
||||
| above + fastseq | 10.7 samples/s | 17.1 samples/s | 21.8 samples/s | 25.1 samples/s |
|
||||
|
||||
### Model
|
||||
|
||||
|
|
|
@ -10,8 +10,8 @@ A pre-trained language model for sequence-to-sequence learning with a novel self
|
|||
|
||||
| BatchSize | 32 | 64 | 128 |
|
||||
|:--------------------:|:-------------:|:---------------:|:--------------:|
|
||||
| prophetnet | 2.4 samples/s | 2.8 samples/s | OOM |
|
||||
| above + fastseq | 6.0 samples/s | 7.6 samples/s | 10.7 samples/s |
|
||||
| prophetnet (fs 0.9.0) | 2.4 samples/s | 2.8 samples/s | OOM |
|
||||
| above + fastseq | 6.1 samples/s | 9.1 samples/s | 11.9 samples/s |
|
||||
|
||||
|
||||
### Model
|
||||
|
|
|
@ -7,8 +7,8 @@ https://arxiv.org/abs/1806.00187
|
|||
|
||||
| BatchSize | 256 | 512 | 1024 |
|
||||
|:----------------:|:--------------:|:--------------:|:--------------:|
|
||||
| fairseq-0.9.0 | 96 samples/s | OOM | OOM |
|
||||
| above + fastseq | 350 samples/s | 400 samples/s | 417 samples/s |
|
||||
| fairseq-0.10.2 | 144.5 samples/s | OOM | OOM |
|
||||
| above + fastseq | 364.1 samples/s | 402.1 samples/s | 422.8 samples/s |
|
||||
|
||||
### Training a new model on WMT'16 En-De
|
||||
|
||||
|
|
|
@ -15,15 +15,15 @@ FASTSEQ_UNITTEST_LOG_XML_DIR = os.getenv(
|
|||
FASTSEQ_LOG_FORMAT = (
|
||||
'%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s')
|
||||
|
||||
FASTSEQ_VERSION = '0.0.4'
|
||||
FASTSEQ_VERSION = '0.2.0'
|
||||
|
||||
# supported versions of transformers
|
||||
MIN_TRANSFORMERS_VERSION = '3.0.2'
|
||||
MAX_TRANSFORMER_VERSION = '3.0.2'
|
||||
|
||||
# supported versions of fairseq
|
||||
MIN_FAIRSEQ_VERSION = '0.9.0'
|
||||
MAX_FAIRSEQ_VERSION = '0.9.0'
|
||||
MIN_FAIRSEQ_VERSION = '0.10.0'
|
||||
MAX_FAIRSEQ_VERSION = '0.10.2'
|
||||
|
||||
#Set following variable to use Efficient-Lossless Attention
|
||||
USE_EL_ATTN = True if os.getenv('USE_EL_ATTN', '0') == '1' else False
|
||||
|
|
|
@ -29,7 +29,7 @@ class BertDictionary(Dictionary):
|
|||
bos='<s>',
|
||||
extra_special_symbols=None,
|
||||
):
|
||||
super().__init__(pad, eos, unk, bos, extra_special_symbols)
|
||||
super().__init__(pad=pad, eos=eos, unk=unk, bos=bos, extra_special_symbols=extra_special_symbols)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, filename):
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
"""Hub interface for ProphetNet"""
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from fastseq.logging import get_logger
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
@ -16,6 +18,7 @@ import torch.nn.functional as F
|
|||
from fairseq import utils
|
||||
from fairseq.data import encoders
|
||||
|
||||
logger = get_logger(__name__, logging.INFO)
|
||||
|
||||
class ProphetNetHubInterface(nn.Module):
|
||||
"""A simple PyTorch Hub interface to BART.
|
||||
|
@ -127,7 +130,7 @@ class ProphetNetHubInterface(nn.Module):
|
|||
gen_args.beam = beam
|
||||
for k, v in kwargs.items():
|
||||
setattr(gen_args, k, v)
|
||||
generator = self.task.build_generator(gen_args)
|
||||
generator = self.task.build_generator([self.model], gen_args)
|
||||
translations = self.task.inference_step(
|
||||
generator,
|
||||
[self.model],
|
||||
|
@ -137,7 +140,7 @@ class ProphetNetHubInterface(nn.Module):
|
|||
|
||||
if verbose:
|
||||
src_str_with_unk = self.string(tokens)
|
||||
print('S\t{}'.format(src_str_with_unk))
|
||||
logger.info("S\t{}".format(src_str_with_unk))
|
||||
|
||||
def getarg(name, default):
|
||||
return getattr(gen_args, name, getattr(self.args, name, default))
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
"""NgramMultiheadAttention"""
|
||||
|
||||
import math
|
||||
from fairseq.incremental_decoding_utils import with_incremental_state
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Parameter
|
||||
|
@ -30,7 +31,7 @@ def ngram_attention_bias(length, num_skip):
|
|||
bias_result.append(bias_n_skip)
|
||||
return torch.from_numpy(np.array(bias_result, dtype=np.float32))
|
||||
|
||||
|
||||
@with_incremental_state
|
||||
class NgramMultiheadAttention(nn.Module):
|
||||
"""Multi-headed attention.
|
||||
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -1,27 +1,52 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Optimize fairseq-generate (v0.9.0)"""
|
||||
"""Optimize fairseq-generate (v0.10.2)"""
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from itertools import chain
|
||||
from multiprocessing import Queue, JoinableQueue
|
||||
from torch.multiprocessing import Process
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from fairseq_cli.generate import main
|
||||
from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
|
||||
from fairseq.data import encoders
|
||||
from fairseq.meters import StopwatchMeter, TimeMeter
|
||||
from fairseq.options import add_generation_args
|
||||
from fairseq.utils import apply_to_sample
|
||||
|
||||
from fairseq import scoring, checkpoint_utils, tasks, utils
|
||||
from fairseq.logging import progress_bar
|
||||
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
||||
from fastseq.utils.api_decorator import replace
|
||||
from fairseq.options import add_generation_args
|
||||
|
||||
GENERATE_FINISHED = "done"
|
||||
POSTPROCESS_FINISHED = None
|
||||
|
||||
original_add_generation_args = add_generation_args
|
||||
|
||||
@replace(add_generation_args)
|
||||
def add_generation_args_v2(parser):
|
||||
group = original_add_generation_args(parser)
|
||||
# fmt: off
|
||||
group.add_argument(
|
||||
'--postprocess-workers',
|
||||
default=1,
|
||||
type=int,
|
||||
choices=range(1, 128, 1),
|
||||
metavar='N',
|
||||
help='number of worker for post process')
|
||||
group.add_argument(
|
||||
'--decode-hypothesis',
|
||||
action="store_true",
|
||||
default=True)
|
||||
group.add_argument(
|
||||
'--use-el-attn',
|
||||
action='store_true',
|
||||
help='Use Efficient Lossless Attention optimization ? ')
|
||||
# fmt: on
|
||||
|
||||
def move_to_cpu(sample):
|
||||
def _move_to_cpu(tensor):
|
||||
# PyTorch has poor support for half tensors (float16) on CPU.
|
||||
|
@ -33,12 +58,15 @@ def move_to_cpu(sample):
|
|||
|
||||
return apply_to_sample(_move_to_cpu, sample)
|
||||
|
||||
def convert_base_e_to_base_2(t):
|
||||
return t / math.log(2)
|
||||
|
||||
class IOProcess(Process):
|
||||
"""
|
||||
Single process to hanlde IO and compute metrics
|
||||
Single process to handle IO and compute metrics
|
||||
"""
|
||||
def __init__(self, args, task, message_queue):
|
||||
|
||||
def __init__(self, args, task, message_queue, output_file):
|
||||
"""
|
||||
Process to handle IO and compute metrics
|
||||
|
||||
|
@ -50,18 +78,14 @@ class IOProcess(Process):
|
|||
"""
|
||||
super(IOProcess, self).__init__()
|
||||
self.tgt_dict = task.target_dictionary
|
||||
|
||||
|
||||
# Generate and compute BLEU score
|
||||
if args.sacrebleu:
|
||||
self.scorer = bleu.SacrebleuScorer()
|
||||
else:
|
||||
self.scorer = bleu.Scorer(self.tgt_dict.pad(), self.tgt_dict.eos(),
|
||||
self.tgt_dict.unk())
|
||||
|
||||
self.scorer = scoring.build_scorer(args, self.tgt_dict)
|
||||
self.args = args
|
||||
self.message_queue = message_queue
|
||||
self.has_target = False
|
||||
|
||||
self.output_file = output_file
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
msg = self.message_queue.get()
|
||||
|
@ -74,23 +98,27 @@ class IOProcess(Process):
|
|||
self.has_target = True
|
||||
elif msg == GENERATE_FINISHED:
|
||||
if self.has_target:
|
||||
print('| Generate {} with beam={}: {}'.format(
|
||||
self.args.gen_subset, self.args.beam,
|
||||
self.scorer.result_string()))
|
||||
if self.args.bpe and not self.args.sacrebleu:
|
||||
if self.args.remove_bpe:
|
||||
print("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
|
||||
else:
|
||||
print("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
|
||||
print("Generate {} with beam={}: {}".format(
|
||||
self.args.gen_subset, self.args.beam, self.scorer.result_string()),
|
||||
file=self.output_file,)
|
||||
break
|
||||
else:
|
||||
print(msg)
|
||||
print(msg, file = self.output_file)
|
||||
self.message_queue.task_done()
|
||||
self.message_queue.close()
|
||||
self.message_queue.join_thread()
|
||||
|
||||
|
||||
class PostProcess(Process):
|
||||
'''
|
||||
Use multiple process to do detokenize
|
||||
'''
|
||||
"""
|
||||
Use multiple processes to do detokenization
|
||||
"""
|
||||
|
||||
def __init__(self, args, task, data_queue, message_queue):
|
||||
def __init__(self, args, task, data_queue, message_queue, generator):
|
||||
"""
|
||||
Handle detokenize and belu score computation
|
||||
|
||||
|
@ -115,19 +143,15 @@ class PostProcess(Process):
|
|||
self.align_dict = utils.load_align_dict(args.replace_unk)
|
||||
|
||||
# Generate and compute BLEU score
|
||||
if args.sacrebleu:
|
||||
self.scorer = bleu.SacrebleuScorer()
|
||||
else:
|
||||
self.scorer = bleu.Scorer(self.tgt_dict.pad(), self.tgt_dict.eos(),
|
||||
self.tgt_dict.unk())
|
||||
|
||||
self.scorer = scoring.build_scorer(args, self.tgt_dict)
|
||||
self.args = args
|
||||
self.task = task
|
||||
self.data_queue = data_queue
|
||||
self.message_queue = message_queue
|
||||
self.generator = generator
|
||||
if args.decode_hypothesis:
|
||||
self.tokenizer = encoders.build_tokenizer(args)
|
||||
self.bpe = encoders.build_bpe(args)
|
||||
self.tokenizer = task.build_tokenizer(args)
|
||||
self.bpe = task.build_bpe(args)
|
||||
|
||||
def _decode(self, x):
|
||||
if self.bpe is not None:
|
||||
|
@ -136,21 +160,34 @@ class PostProcess(Process):
|
|||
x = self.tokenizer.decode(x)
|
||||
return x
|
||||
|
||||
def _get_symbols_to_strip_from_output(self, generator):
|
||||
if hasattr(generator, "symbols_to_strip_from_output"):
|
||||
return generator.symbols_to_strip_from_output
|
||||
else:
|
||||
return {generator.eos}
|
||||
|
||||
def _detokenize(self, sample, hypos):
|
||||
""" detokenize and compute BELU """
|
||||
"""
|
||||
Detokenize and compute BELU
|
||||
"""
|
||||
message_list = []
|
||||
for i, sample_id in enumerate(sample['id'].tolist()):
|
||||
has_target = sample['target'] is not None
|
||||
|
||||
# Remove padding
|
||||
src_tokens = utils.strip_pad(
|
||||
sample['net_input']['src_tokens'][i, :], self.tgt_dict.pad())
|
||||
if "src_tokens" in sample["net_input"]:
|
||||
src_tokens = utils.strip_pad(
|
||||
sample["net_input"]["src_tokens"][i, :], self.tgt_dict.pad()
|
||||
)
|
||||
else:
|
||||
src_tokens = None
|
||||
target_tokens = None
|
||||
if has_target:
|
||||
target_tokens = utils.strip_pad(sample['target'][i, :],
|
||||
self.tgt_dict.pad()).int()
|
||||
target_tokens = (
|
||||
utils.strip_pad(sample["target"][i, :], self.tgt_dict.pad()).int().cpu()
|
||||
)
|
||||
|
||||
# Either retrieve the original sentences or regenerate them from tokens.
|
||||
# Either retrieve the original sentences or regenerate them from tokens
|
||||
if self.align_dict is not None:
|
||||
src_str = self.task.dataset(
|
||||
self.args.gen_subset).src.get_original_text(sample_id)
|
||||
|
@ -164,8 +201,11 @@ class PostProcess(Process):
|
|||
src_str = ""
|
||||
if has_target:
|
||||
target_str = self.tgt_dict.string(
|
||||
target_tokens, self.args.remove_bpe, escape_unk=True)
|
||||
|
||||
target_tokens,
|
||||
self.args.remove_bpe,
|
||||
escape_unk = True,
|
||||
extra_symbols_to_ignore = self._get_symbols_to_strip_from_output(self.generator),
|
||||
)
|
||||
if not self.args.quiet:
|
||||
if self.src_dict is not None:
|
||||
if self.args.decode_hypothesis:
|
||||
|
@ -181,77 +221,74 @@ class PostProcess(Process):
|
|||
else:
|
||||
message_list.append('T-{}\t{}'.format(
|
||||
sample_id, target_str))
|
||||
|
||||
|
||||
# Process top predictions
|
||||
for j, hypo in enumerate(hypos[i][:self.args.nbest]):
|
||||
hypo_tokens, hypo_str, alignment = \
|
||||
utils.post_process_prediction(
|
||||
hypo_tokens=hypo['tokens'].int(),
|
||||
src_str=src_str,
|
||||
alignment=hypo['alignment'],
|
||||
align_dict=self.align_dict,
|
||||
tgt_dict=self.tgt_dict,
|
||||
remove_bpe=self.args.remove_bpe,
|
||||
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
|
||||
hypo_tokens = hypo['tokens'].int(),
|
||||
src_str = src_str,
|
||||
alignment = hypo['alignment'],
|
||||
align_dict = self.align_dict,
|
||||
tgt_dict = self.tgt_dict,
|
||||
remove_bpe = self.args.remove_bpe,
|
||||
extra_symbols_to_ignore = self._get_symbols_to_strip_from_output(self.generator),
|
||||
)
|
||||
|
||||
if not self.args.quiet:
|
||||
score = convert_base_e_to_base_2(hypo["score"])
|
||||
message_list.append('H-{}\t{}\t{}'.format(
|
||||
sample_id, score, hypo_str))
|
||||
if self.args.decode_hypothesis:
|
||||
detok_hypo_str = self._decode(hypo_str)
|
||||
message_list.append('D-{}\t{}\t{}'.format(
|
||||
sample_id, hypo['score'], detok_hypo_str))
|
||||
else:
|
||||
message_list.append('H-{}\t{}\t{}'.format(
|
||||
sample_id, hypo['score'], hypo_str))
|
||||
sample_id, score, detok_hypo_str))
|
||||
message_list.append('P-{}\t{}'.format(
|
||||
sample_id, ' '.join(
|
||||
map(
|
||||
lambda x: '{:.4f}'.format(x),
|
||||
hypo['positional_scores'].tolist(),
|
||||
convert_base_e_to_base_2(hypo['positional_scores']).tolist(),
|
||||
))))
|
||||
|
||||
if self.args.print_alignment:
|
||||
message_list.append('A-{}\t{}'.format(
|
||||
sample_id, ' '.join([
|
||||
'{}-{}'.format(src_idx, tgt_idx)
|
||||
for src_idx, tgt_idx in alignment
|
||||
])))
|
||||
|
||||
if self.args.print_step:
|
||||
message_list.append('I-{}\t{}'.format(
|
||||
sample_id, hypo['steps']))
|
||||
|
||||
if getattr(self.args, 'retain_iter_history', False):
|
||||
message_list.append("\n".join([
|
||||
'E-{}_{}\t{}'.format(sample_id, step,
|
||||
utils.post_process_prediction(
|
||||
h['tokens'].int(),
|
||||
self.src_str, None, None,
|
||||
self.tgt_dict, None)[1])
|
||||
for step, h in enumerate(hypo['history'])
|
||||
]))
|
||||
|
||||
for step, h in enumerate(hypo['history']):
|
||||
_, h_str, _ = utils.post_process_prediction(
|
||||
hypo_tokens = h['tokens'].int(),
|
||||
src_str = self.src_str,
|
||||
alignment = None,
|
||||
align_dict = None,
|
||||
tgt_dict = self.tgt_dict,
|
||||
remove_bpe = None,
|
||||
)
|
||||
message_list.append('E-{}_{}\t{}'.format(sample_id, step, h_str))
|
||||
|
||||
# Score only the top hypothesis
|
||||
if has_target and j == 0:
|
||||
if (self.align_dict is not None or
|
||||
self.args.remove_bpe is not None):
|
||||
|
||||
# Convert back to tokens for evaluation with unk
|
||||
# replacement and/or without BPE
|
||||
target_tokens = self.tgt_dict.encode_line(
|
||||
target_str, add_if_not_exist=True)
|
||||
if hasattr(self.scorer, 'add_string'):
|
||||
self.message_queue.put((target_str, hypo_str))
|
||||
target_str, add_if_not_exist = True)
|
||||
hypo_tokens = self.tgt_dict.encode_line(
|
||||
detok_hypo_str, add_if_not_exist = True)
|
||||
if hasattr(self.scorer, "add_string"):
|
||||
self.message_queue.put((target_str, detok_hypo_str))
|
||||
else:
|
||||
self.message_queue.put((target_tokens, hypo_tokens))
|
||||
|
||||
self.message_queue.put((target_tokens, hypo_tokens))
|
||||
self.message_queue.put('\n'.join(message_list))
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
r = self.data_queue.get()
|
||||
if r == GENERATE_FINISHED:
|
||||
self.data_queue.put(POSTPROCESS_FINISHED)
|
||||
break
|
||||
elif r is POSTPROCESS_FINISHED:
|
||||
if r == GENERATE_FINISHED or r is POSTPROCESS_FINISHED:
|
||||
self.data_queue.put(POSTPROCESS_FINISHED)
|
||||
break
|
||||
else:
|
||||
|
@ -263,156 +300,182 @@ class PostProcess(Process):
|
|||
self.message_queue.join_thread()
|
||||
self.message_queue.join()
|
||||
|
||||
|
||||
original_add_generation_args = add_generation_args
|
||||
|
||||
|
||||
@replace(add_generation_args)
|
||||
def add_generation_args_v1(parser):
|
||||
group = original_add_generation_args(parser)
|
||||
# fmt: off
|
||||
group.add_argument(
|
||||
'--postprocess-workers',
|
||||
default=1,
|
||||
type=int,
|
||||
choices=range(1, 128, 1),
|
||||
metavar='N',
|
||||
help='number of worker for post process')
|
||||
group.add_argument(
|
||||
'--decode-hypothesis',
|
||||
action="store_true")
|
||||
# fmt: on
|
||||
|
||||
|
||||
@replace(main)
|
||||
def main_v1(args):
|
||||
assert args.path is not None, '--path required for generation!'
|
||||
assert not args.sampling or args.nbest == args.beam, \
|
||||
'--sampling requires --nbest to be equal to --beam'
|
||||
assert args.replace_unk is None or args.raw_text, \
|
||||
'--replace-unk requires a raw text dataset (--raw-text)'
|
||||
|
||||
def _main(args, output_file):
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
stream=output_file,
|
||||
)
|
||||
logger = logging.getLogger("fastseq.optimizer.fairseq.generate")
|
||||
utils.import_user_module(args)
|
||||
|
||||
if args.max_tokens is None and args.max_sentences is None:
|
||||
if args.max_tokens is None and args.batch_size is None:
|
||||
args.max_tokens = 12000
|
||||
print(args)
|
||||
logger.info(args)
|
||||
|
||||
# Fix seed for stochastic decoding
|
||||
if args.seed is not None and not args.no_seed_provided:
|
||||
np.random.seed(args.seed)
|
||||
utils.set_torch_seed(args.seed)
|
||||
use_cuda = torch.cuda.is_available() and not args.cpu
|
||||
|
||||
# Load dataset splits
|
||||
task = tasks.setup_task(args)
|
||||
task.load_dataset(args.gen_subset)
|
||||
|
||||
# Set dictionaries
|
||||
try:
|
||||
src_dict = getattr(task, 'source_dictionary', None)
|
||||
except NotImplementedError:
|
||||
src_dict = None
|
||||
tgt_dict = task.target_dictionary
|
||||
overrides = ast.literal_eval(args.model_overrides)
|
||||
|
||||
# Load ensemble
|
||||
print('| loading model(s) from {}'.format(args.path))
|
||||
models, model_args_ = checkpoint_utils.load_model_ensemble(
|
||||
args.path.split(':'),
|
||||
arg_overrides=eval(args.model_overrides),
|
||||
task=task,
|
||||
logger.info("loading model(s) from {}".format(args.path))
|
||||
models, _ = checkpoint_utils.load_model_ensemble(
|
||||
utils.split_paths(args.path),
|
||||
arg_overrides = overrides,
|
||||
task = task,
|
||||
suffix = getattr(args, "checkpoint_suffix", ""),
|
||||
strict = (args.checkpoint_shard_count == 1),
|
||||
num_shards = args.checkpoint_shard_count,
|
||||
)
|
||||
if args.lm_path is not None:
|
||||
overrides["data"] = args.data
|
||||
try:
|
||||
lms, _ = checkpoint_utils.load_model_ensemble(
|
||||
[args.lm_path],
|
||||
arg_overrides=overrides,
|
||||
task=None,
|
||||
)
|
||||
except:
|
||||
logger.warning("Failed to load language model! Please make sure that the language model dict is the same as target dict and is located in the data dir ({})".format(args.data))
|
||||
raise
|
||||
assert len(lms) == 1
|
||||
else:
|
||||
lms = [None]
|
||||
|
||||
# Optimize ensemble for generation
|
||||
for model in models:
|
||||
model.make_generation_fast_(
|
||||
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
|
||||
need_attn=args.print_alignment,
|
||||
)
|
||||
for model in chain(models, lms):
|
||||
if model is None:
|
||||
continue
|
||||
if args.fp16:
|
||||
model.half()
|
||||
if use_cuda:
|
||||
if use_cuda and not args.pipeline_model_parallel:
|
||||
model.cuda()
|
||||
model.prepare_for_inference_(args)
|
||||
|
||||
# Load dataset (possibly sharded)
|
||||
itr = task.get_batch_iterator(
|
||||
dataset=task.dataset(args.gen_subset),
|
||||
max_tokens=args.max_tokens,
|
||||
max_sentences=args.max_sentences,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
dataset = task.dataset(args.gen_subset),
|
||||
max_tokens = args.max_tokens,
|
||||
max_sentences = args.batch_size,
|
||||
max_positions = utils.resolve_max_positions(
|
||||
task.max_positions(),
|
||||
*[model.max_positions() for model in models]),
|
||||
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=args.required_batch_size_multiple,
|
||||
num_shards=args.num_shards,
|
||||
shard_id=args.shard_id,
|
||||
num_workers=args.num_workers,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
ignore_invalid_inputs = args.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple = args.required_batch_size_multiple,
|
||||
num_shards = args.num_shards,
|
||||
shard_id = args.shard_id,
|
||||
num_workers = args.num_workers,
|
||||
data_buffer_size = args.data_buffer_size,
|
||||
).next_epoch_itr(shuffle = False)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format = args.log_format,
|
||||
log_interval = args.log_interval,
|
||||
default_log_format = ("tqdm" if not args.no_progress_bar else "none"),
|
||||
)
|
||||
|
||||
# Initialize generator
|
||||
gen_timer = StopwatchMeter()
|
||||
generator = task.build_generator(args)
|
||||
|
||||
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight}
|
||||
generator = task.build_generator(
|
||||
models, args, extra_gen_cls_kwargs = extra_gen_cls_kwargs
|
||||
)
|
||||
num_sentences = 0
|
||||
data_queue = Queue()
|
||||
message_queue = JoinableQueue()
|
||||
|
||||
p_list = []
|
||||
for i in range(args.postprocess_workers):
|
||||
p = PostProcess(args, task, data_queue, message_queue)
|
||||
for _ in range(args.postprocess_workers):
|
||||
p = PostProcess(args, task, data_queue, message_queue, generator)
|
||||
p_list.append(p)
|
||||
p.start()
|
||||
|
||||
io_process = IOProcess(args, task, message_queue)
|
||||
io_process = IOProcess(args, task, message_queue, output_file)
|
||||
io_process.start()
|
||||
|
||||
if args.use_el_attn:
|
||||
task.transpose_enc_dec_kv_proj(models)
|
||||
with progress_bar.build_progress_bar(args, itr) as t:
|
||||
wps_meter = TimeMeter()
|
||||
for sample in t:
|
||||
cpu_sample = sample
|
||||
if 'net_input' not in sample:
|
||||
continue
|
||||
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
||||
|
||||
wps_meter = TimeMeter()
|
||||
for sample in progress:
|
||||
cpu_sample = sample
|
||||
if 'net_input' not in sample:
|
||||
continue
|
||||
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
||||
|
||||
prefix_tokens = None
|
||||
if args.prefix_size > 0:
|
||||
prefix_tokens = sample['target'][:, :args.prefix_size]
|
||||
prefix_tokens = None
|
||||
if args.prefix_size > 0:
|
||||
prefix_tokens = sample['target'][:, :args.prefix_size]
|
||||
|
||||
gen_timer.start()
|
||||
try:
|
||||
hypos = task.inference_step(
|
||||
generator, models, sample, prefix_tokens)
|
||||
except:
|
||||
logging.exception(sys.exc_info()[0])
|
||||
for p in p_list:
|
||||
p.terminate()
|
||||
io_process.terminate()
|
||||
data_queue.close()
|
||||
message_queue.close()
|
||||
sys.exit(1)
|
||||
|
||||
num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
|
||||
gen_timer.stop(num_generated_tokens)
|
||||
|
||||
hypos = [h[:args.nbest] for h in hypos]
|
||||
hypos = move_to_cpu(hypos) if use_cuda else hypos
|
||||
data_queue.put((cpu_sample, hypos))
|
||||
|
||||
wps_meter.update(num_generated_tokens)
|
||||
t.log({'wps': round(wps_meter.avg)})
|
||||
num_sentences += cpu_sample['nsentences']
|
||||
constraints = None
|
||||
if "constraints" in sample:
|
||||
constraints = sample["constraints"]
|
||||
|
||||
gen_timer.start()
|
||||
try:
|
||||
hypos = task.inference_step(
|
||||
generator, models, sample, prefix_tokens, constraints)
|
||||
except:
|
||||
logging.exception(sys.exc_info()[0])
|
||||
for p in p_list:
|
||||
p.terminate()
|
||||
io_process.terminate()
|
||||
data_queue.close()
|
||||
message_queue.close()
|
||||
sys.exit(1)
|
||||
num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
|
||||
gen_timer.stop(num_generated_tokens)
|
||||
hypos = [h[:args.nbest] for h in hypos]
|
||||
hypos = move_to_cpu(hypos) if use_cuda else hypos
|
||||
data_queue.put((cpu_sample, hypos))
|
||||
wps_meter.update(num_generated_tokens)
|
||||
progress.log({'wps': round(wps_meter.avg)})
|
||||
num_sentences += (
|
||||
cpu_sample['nsentences'] if "nsentences" in cpu_sample else cpu_sample["id"].numel()
|
||||
)
|
||||
|
||||
data_queue.put(GENERATE_FINISHED)
|
||||
for p in p_list:
|
||||
p.join()
|
||||
|
||||
sent_throught = num_sentences / gen_timer.sum if num_sentences > 0 else 0
|
||||
tokens_throught = 1. / gen_timer.avg if num_sentences > 0 else 0
|
||||
|
||||
message_queue.put(
|
||||
'| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'. # pylint: disable=line-too-long
|
||||
format(num_sentences, gen_timer.n, gen_timer.sum, sent_throught,
|
||||
tokens_throught))
|
||||
|
||||
message_queue.put(GENERATE_FINISHED)
|
||||
io_process.join()
|
||||
|
||||
sent_through = num_sentences / gen_timer.sum if num_sentences > 0 else 0
|
||||
tokens_through = 1. / gen_timer.avg if num_sentences > 0 else 0
|
||||
logger.info("NOTE: hypothesis and token scores are output in base 2")
|
||||
logger.info(
|
||||
"Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
|
||||
num_sentences,
|
||||
gen_timer.n,
|
||||
gen_timer.sum,
|
||||
sent_through,
|
||||
tokens_through,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
@replace(main)
|
||||
def main_v2(args):
|
||||
assert args.path is not None, '--path required for generation!'
|
||||
assert (
|
||||
not args.sampling or args.nbest == args.beam
|
||||
), "--sampling requires --nbest to be equal to --beam"
|
||||
assert (
|
||||
args.replace_unk is None or args.dataset_impl == "raw"
|
||||
), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
|
||||
|
||||
if args.results_path is not None:
|
||||
os.makedirs(args.results_path, exist_ok = True)
|
||||
output_path = os.path.join(
|
||||
args.results_path, "generate-{}.txt".format(args.gen_subset)
|
||||
)
|
||||
with open(output_path, "w", buffering = 1, encoding = "utf-8") as h:
|
||||
return _main(args, h)
|
||||
else:
|
||||
return _main(args, sys.stdout)
|
||||
|
||||
|
|
|
@ -84,6 +84,12 @@ BART_MODEL_URLS[
|
|||
BART_MODEL_URLS[
|
||||
'bart.large.xsum'] = 'https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz'
|
||||
|
||||
CNNDM_URL = 'https://fastseq.blob.core.windows.net/data/tasks/cnn_dm.128/len-1024.bin'
|
||||
CACHED_CNNDM_DATA_DIR = os.path.join(FASTSEQ_CACHE_DIR, 'cnn_dm.128/len-1024.bin')
|
||||
|
||||
CNNDM_RAW_URL = 'https://fastseq.blob.core.windows.net/data/tasks/unit_tests'
|
||||
CACHED_CNNDM_RAW_DATA_DIR = os.path.join(FASTSEQ_CACHE_DIR, 'raw_cnndm_data')
|
||||
|
||||
CACHED_BART_MODEL_DIR = os.path.join(FASTSEQ_CACHE_DIR, 'fairseq_bart_models')
|
||||
|
||||
CACHED_BART_MODEL_PATHS = {}
|
||||
|
|
|
@ -14,17 +14,27 @@ from fairseq import options
|
|||
def parse_additional_args():
|
||||
parser = options.get_generation_parser()
|
||||
parser.add_argument(
|
||||
'--use-el-attn',
|
||||
'--use_el_attn',
|
||||
action='store_true',
|
||||
help='Use Efficient Lossless Attention optimization ? ')
|
||||
parser.add_argument(
|
||||
'--postprocess_workers',
|
||||
default=1,
|
||||
type=int,
|
||||
choices=range(1, 128, 1),
|
||||
metavar='N',
|
||||
help='number of worker for post process')
|
||||
parser.add_argument(
|
||||
'--decode_hypothesis',
|
||||
action="store_true")
|
||||
args = options.parse_args_and_arch(parser)
|
||||
return args
|
||||
|
||||
def cli_main():
|
||||
os.environ['USE_EL_ATTN'] = '1' if '--use-el-attn' in sys.argv else '0'
|
||||
from fastseq.optimizer.fairseq.generate import main_v1 # pylint: disable=import-outside-toplevel
|
||||
from fastseq.optimizer.fairseq.generate import main_v2 # pylint: disable=import-outside-toplevel
|
||||
args = parse_additional_args()
|
||||
main_v1(args)
|
||||
main_v2(args)
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli_main()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
yapf >= 0.30.0
|
||||
torch >= 1.4.0
|
||||
fairseq == 0.9.0
|
||||
fairseq == 0.10.2
|
||||
transformers == 3.0.2
|
||||
absl-py >= 0.9.0
|
||||
filelock >= 3.0.12
|
||||
|
|
6
setup.py
6
setup.py
|
@ -4,9 +4,9 @@
|
|||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
FASTSEQ_VERSION = '0.0.4'
|
||||
MIN_FAIRSEQ_VERSION = '0.9.0'
|
||||
MAX_FAIRSEQ_VERSION = '0.9.0'
|
||||
FASTSEQ_VERSION = '0.2.0'
|
||||
MIN_FAIRSEQ_VERSION = '0.10.0'
|
||||
MAX_FAIRSEQ_VERSION = '0.10.2'
|
||||
MIN_TRANSFORMERS_VERSION = '3.0.2'
|
||||
MAX_TRANSFORMER_VERSION = '3.0.2'
|
||||
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Test the optimizations on FairSeq to make sure the changes do not affect the
|
||||
model accuracy.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
from absl.testing import absltest, parameterized
|
||||
from fairseq.models.bart.model import BARTModel
|
||||
|
||||
from fastseq.logging import get_logger
|
||||
|
||||
from fastseq.utils.file_utils import decompress_file, make_dirs, wget
|
||||
from fastseq.utils.test_utils import (BART_MODEL_URLS, CACHED_BART_MODEL_DIR,
|
||||
CACHED_BART_MODEL_PATHS, CNNDM_URL, CACHED_CNNDM_DATA_DIR,
|
||||
fastseq_test_main, TestCaseBase)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class FairseqGenerateCLITest(TestCaseBase):
|
||||
"""Test the optimizations on FairSeq
|
||||
|
||||
`bart.large.cnn` model is used for benchmarking. If it does not exist, it
|
||||
will be downloaded first. As the the model is big, it will take a while to
|
||||
download. Once downloaded, it will be cached for future usage.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""set up the test environment"""
|
||||
|
||||
super(FairseqGenerateCLITest, self).setUp()
|
||||
# TODO: create a dummy model instead of loading a large-size model.
|
||||
if not os.path.exists(CACHED_BART_MODEL_PATHS['bart.large.cnn']):
|
||||
make_dirs(CACHED_BART_MODEL_DIR, exist_ok=True)
|
||||
tar_model_path = os.path.join(CACHED_BART_MODEL_DIR,
|
||||
'bart.large.cnn.tar.gz')
|
||||
with open(tar_model_path, 'xb') as tar_model_file:
|
||||
wget(BART_MODEL_URLS['bart.large.cnn'], tar_model_file)
|
||||
decompress_file(tar_model_path, CACHED_BART_MODEL_DIR)
|
||||
|
||||
self.source_path = CACHED_CNNDM_DATA_DIR
|
||||
make_dirs(self.source_path, exist_ok=True)
|
||||
file_list = ["dict.source.txt", "dict.target.txt", "valid.source-target.source.bin", "valid.source-target.target.bin", "valid.source-target.source.idx", "valid.source-target.target.idx"]
|
||||
for f in file_list:
|
||||
f_path = os.path.join(self.source_path, f)
|
||||
if not os.path.exists(f_path):
|
||||
with open(f_path, 'xb') as new_file:
|
||||
wget(os.path.join(CNNDM_URL, f), new_file)
|
||||
new_file.close()
|
||||
|
||||
self.bart_path = CACHED_BART_MODEL_PATHS['bart.large.cnn'] + '/model.pt'
|
||||
|
||||
@parameterized.named_parameters({
|
||||
'testcase_name': 'Normal',
|
||||
'beam_size': 4,
|
||||
'batch_size': 16,
|
||||
'lenpen': 2.0,
|
||||
'max_len_b': 140,
|
||||
'min_len': 55,
|
||||
'no_repeat_ngram_size': 3,
|
||||
})
|
||||
def test_generate_cli(self, beam_size, batch_size,
|
||||
lenpen, max_len_b, min_len,
|
||||
no_repeat_ngram_size):
|
||||
"""Test the command line interface for fastseq. Make sure the changes do not
|
||||
affect the model accuracy for beam search optimization and el attn optimization
|
||||
|
||||
Args:
|
||||
beam_size (int): beam size.
|
||||
batch_size (int): batch size.
|
||||
need_attn (bool): indicate if attention is needed.
|
||||
lenpen (float): length penalty, where <1.0 favors shorter, >1.0
|
||||
favors longer sentences.
|
||||
max_len_b (int): max length of generated text.
|
||||
min_len (int): min length of generated text.
|
||||
no_repeat_ngram_size (int): size of no repeat gram.
|
||||
"""
|
||||
options = ["--path", self.bart_path,
|
||||
"--task", "translation",
|
||||
"--batch-size", str(batch_size),
|
||||
"--gen-subset", "valid",
|
||||
"--truncate-source",
|
||||
"--bpe", "gpt2",
|
||||
"--beam", str(beam_size),
|
||||
"--num-workers", "4",
|
||||
"--min-len", str(min_len),
|
||||
"--max-len-b", str(max_len_b),
|
||||
"--no-repeat-ngram-size", str(no_repeat_ngram_size),
|
||||
"--lenpen", str(lenpen),
|
||||
"--skip-invalid-size-inputs-valid-test",
|
||||
self.source_path]
|
||||
fairseq_outs = subprocess.check_output(['fairseq-generate'] + options).decode("utf-8").split("\n")
|
||||
try:
|
||||
import fastseq
|
||||
except ImportError:
|
||||
logger.error("Failed to import fastseq")
|
||||
|
||||
# test beam search opt
|
||||
options.append("--decode-hypothesis")
|
||||
|
||||
fastseq_outs = subprocess.check_output(['fastseq-generate-for-fairseq'] + options).decode("utf-8").split("\n")
|
||||
# only compare decoded hypotheses
|
||||
fairseq_outs = [l.split() for l in fairseq_outs]
|
||||
fairseq_outs = [l for l in fairseq_outs if len(l) > 2 and l[0][0] is 'D']
|
||||
fastseq_outs = [l.split() for l in fastseq_outs]
|
||||
fastseq_outs = [l for l in fastseq_outs if len(l) > 2 and l[0][0] is 'D']
|
||||
assert len(fairseq_outs) == len(fastseq_outs)
|
||||
assert len(fairseq_outs) == 128
|
||||
for i, expected_out in enumerate(fairseq_outs):
|
||||
self.assertEqual(expected_out[2:], fastseq_outs[i][2:])
|
||||
|
||||
fastseq_outs = None
|
||||
|
||||
# test el attn opt
|
||||
options.append("--use-el-attn")
|
||||
try:
|
||||
fastseq_outs = subprocess.check_output(['fastseq-generate-for-fairseq'] + options).decode("utf-8").split("\n")
|
||||
except subprocess.CalledProcessError as error:
|
||||
print('Error code:', error.returncode, '. Output:', error.output.decode("utf-8"))
|
||||
# only compare decoded hypotheses
|
||||
fastseq_outs = [l.split() for l in fastseq_outs]
|
||||
fastseq_outs = [l for l in fastseq_outs if len(l) > 2 and l[0][0] is 'D']
|
||||
assert len(fairseq_outs) == len(fastseq_outs)
|
||||
assert len(fairseq_outs) == 128
|
||||
for i, expected_out in enumerate(fairseq_outs):
|
||||
self.assertEqual(expected_out[2:], fastseq_outs[i][2:])
|
||||
|
||||
if __name__ == "__main__":
|
||||
fastseq_test_main()
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -1,128 +0,0 @@
|
|||
French prosecutor says he is not aware of any video footage from on board the plane. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The publications said that they watched the video, which was found by a source close to the investigation. An official with France 's accident investigation agency , the BEA, said the agency is notaware of any such video.
|
||||
The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal accession was marked with a ceremony at The Hague, where the court is based. Israel and the United States opposed the Palestinians ' efforts to join the body. The ICC opened a preliminary examination into the situation in Palestinian territories in January.
|
||||
Amnesty International's annual report catalogs the use of state-sanctioned killing as a punitive measure across the globe. The organization found positive developments worldwide, with most regions seeming to show reductions in the number of executions. The U.S. remains one of the worst offenders for imposing capital punishment, with only Iran executing more people in 2014.
|
||||
Amnesty International releases its annual review of the death penalty worldwide. In Pakistan, the government lifted a six-year moratorium on the execution of civilians. In Indonesia, authorities announced plans to execute mainly drug traffickers. A sharp spike in death sentences recorded in 2014 -- up more than 500 on the previous year.
|
||||
Anne Frank died of typhus in a Nazi concentration camp at the age of 15. Researchers re-examined archives of the Red Cross, the International Training Service and the Bergen-Belsen Memorial. They concluded that Anne and Margot probably did not survive to March 1945 -- contradicting the date of death previously determined by Dutch authorities.
|
||||
A Duke student has admitted to hanging a noose from a tree, university officials say. The prestigious private school did n't identify the student, citing federal privacy laws. The student was identified during an investigation by campus police and student affairs. Officials are still trying to determine if other people were involved in the incident.
|
||||
The Rev. Robert H. Schuller died Thursday at age 88. He was the founder of the television ministry "Hour of Power" He was diagnosed with esophageal cancer in August 2013. His Crystal Cathedral megachurch is now owned by the Roman Catholic Church.
|
||||
Stray pooch in Washington State has used up at least three of her own after being hit by a car. The dog was apparently whacked on the head with a hammer and buried in a field. Four days after her apparent death, the dog managed to stagger to a nearby farm. She suffered a dislocated jaw, leg injuries and a caved-in sinus cavity.
|
||||
Mohammad Javad Zarif is the Iranian foreign minister. He has been John Kerry 's opposite number in securing a breakthrough in nuclear discussions. Zarif received a hero 's welcome as he arrived in Iran on a sunny Friday morning. But there are some facts about Zarif that are less well-known.
|
||||
Bob Barker returns to host The Price Is Right for the first time in eight years. Barker hosted the TV game show for 35 years before stepping down in 2007. Looking spry at 91, Barker handled the first price-guessing game of the show before turning hosting duties over to Drew Carey.
|
||||
Trey Moses asked Ellie Meredith to be his prom date. Trey made the prom-posal in the gym during Ellie 's P.E. class. Ellie has struggled with friendships since elementary school. Trey is headed to play college ball next year at Ball State in Louisville, Kentucky.
|
||||
Michele Bachmann compares President Obama to the co-pilot of the doomed Germanwings flight. Bachmann: Obama is for the 300 million souls of the United States what Andreas Lubitz was for the 150 souls on the German Wings flight. Many comments on her Facebook page blasted the former representative.
|
||||
California is a breadbasket to the nation, growing more than a third of its vegetables and nearly two-thirds of its fruits and nuts. A strong dollar allows producers to import crops that may be withering under the absence of West Coast rain or other misfortunes elsewhere in the nation. Though fruits and vegetable prices fell in February , overall prices are expected to rise this year.
|
||||
Walmart is emerging as a bellwether for shifting public opinion on hot-button political issues that divide conservatives and liberals. Former Minnesota Gov. Tim Pawlenty says Walmart 's actions foreshadow where the Republican Party will need to move. The backlash over the religious freedom measures in Indiana and Arkansas this week is shining a bright light on the broader business community 's overwhelming support for workplace policies that promote gay equality.
|
||||
The five were exposed to Ebola in Sierra Leone in March. None of them developed the deadly virus. They are clinicians for Partners in Health, a Boston-based aid group. They all had contact with a colleague who was diagnosed with Ebola and is being treated at the National Institutes of Health.
|
||||
Andrew Getty, 47, appears to have died of natural causes, police say. An autopsy will be conducted, but there is no criminal investigation underway. Andrew Getty was found on his side near a bathroom in his home, KTLA reports. He was the grandson of oil tycoon J. Paul Getty, who died in 1976.
|
||||
Mike Pence is drawing huge heat for his controversial decision to sign a religious freedom law last week. John Avlon: The bill was Pence 's way of shoring up his street cred among ultraconservatives. He says there is no way a Republican can get through the pending primary without denouncing LGBT rights. Avlon says the issue of LGBT rights will turn numerous Americans into single issue voters.
|
||||
Maysak gained super typhoon status just a few days ago. It has since lost a lot of steam as it has spun west in the Pacific Ocean. It boasts steady winds of more than 70 mph -LRB- 115 kph -RRB- and gusts up to 90 mph. It 's expected to make landfall Sunday morning on the southeastern coast.
|
||||
Louis Jordan, 37, left Conway, South Carolina, to fish in the ocean. A storm capsized his boat and broke his mast, so he couldn't fix it right away. After his food and water ran out, it became an issue of survival. The boat capsized two more times before he was rescued, according to Jordan.
|
||||
Paul Walker died in November 2013 at the age of 40 after a car crash. The actor was on break from filming `` Furious 7 '' at the time of the fiery accident. The script was rewritten and special effects were used to finish scenes, with Walker 's brothers serving as body doubles. There are scenes that will resonate with the audience -- including the ending.
|
||||
The U.S. and its negotiating partners reached a very strong framework agreement with Iran. Peter Bergen: The debate that has already begun will likely result in more heat than light. He says the objective has always been to structure an agreement so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. Bergen says the inspections provisions that are part of this agreement are designed to protect against any covert action.
|
||||
Mötley Crüe 's Vince Neil reminded us of the dangers of tackling The Star-Spangled Banner. Whitney Houston set the modern standard for the national anthem at Super Bowl XXV. Jimi Hendrix inflamed mainstream America with his psychedelic take on the anthem.
|
||||
Yahya Rashid, a UK national from northwest London, was detained at Luton airport on Tuesday. He 's been charged with engaging in conduct in preparation of acts of terrorism, police say. Rashid is due to appear in Westminster Magistrates ' Court on Wednesday.
|
||||
The total lunar eclipse started at 3:16 a.m. Pacific Daylight Time. People west of the Mississippi River will have the best view. Parts of South America, India, China and Russia also will be able to see the eclipse. The eclipse will only last four minutes and 43 seconds, NASA says.
|
||||
Memories Pizza in Walkerton, Indiana, at center of debate over state 's Religious Freedom Restoration Act. Owners said they 'd refuse to cater a same-sex couple 's wedding. Critics said new law would allow businesses to discriminate against gays and lesbians. But supporters rallied and donated more than $ 842,000 for business.
|
||||
Miracles do happen, though, and not just in Hollywood. Louis Jordan says that he set off on his 35-foot sailboat from South Carolina in late January. Jose Salvador Alvarenga says his journey began in Paredon Viejo, Mexico, in late 2012. Ron Ingraham is n't one of those people -- he 's a fisherman who survived at sea.
|
||||
Iranian authorities imposed the ban on women attending men 's sports events after the revolution in 1979. The ban was reinstated in 2005 after the more hard-line Mahmoud Ahmadinejad came to power. FIFA President Sepp Blatter called on Iran last month to end its ban.
|
||||
Israeli Prime Minister Benjamin Netanyahu criticizes the Iran nuclear deal. He says he sees better options than "this bad deal or war" Democrats and Republicans spar over the framework announced last week to lift Western sanctions. GOP contenders for the 2016 presidential nomination lambasted it as giving Iran too much flexibility.
|
||||
A trip to a former heavyweight champ 's gaudy, abandoned mansion. The tallest and fastest '' giga-coaster '' in the world. A dramatic interview with a famed spiritual leader. A professor of physics at a British university asked 100 people to create a composite with facial features.
|
||||
Easter celebrates the completion of Christ 's mission of salvation in the Crucifixion and Resurrection. It does n't fall on the same day every year but shifts around in spring depending upon cosmic events. A blood moon appeared in the sky early Saturday -- right between Good Friday and Easter Sunday and during Passover.
|
||||
Police in Indian city of Malegaon are requiring identity cards for cattle owners. Cows are considered holy and revered by that state 's majority Hindu population. The Maharashtra Animal Preservation Bill now includes bans on the killing of bulls and bullocks. The consumption or sale of beef could now land you in prison for five years.
|
||||
Keonna Thomas is one of three women arrested this week on terror charges. Two New York women were also taken into custody. The FBI said Thomas purchased an electronic visa to Turkey on March 23. Turkey is known as the easiest place from which to enter Syria and join ISIS.
|
||||
Online activity is the alternative to traditional mainstream media in authoritarian countries. Lack of access to traditional print and broadcast media is driving force leading disaffected voices to post online. Jailing journalists is one thing, but being killed and doing little or nothing about it is another. Since 1992, 11 of the 11 journalists killed for their work online have been bloggers.
|
||||
Robert Lewis Burns Jr. was the original drummer in Southern rock band Lynyrd Skynyrd. Burns, 64, died after his car hit a mailbox and a tree in Cartersville, Georgia. He was not restrained at the time of the crash, a Georgia State Patrol spokesman says.
|
||||
Avril Lavigne was bedridden for five months after contracting Lyme disease. Lavigne believes that she was bitten by a tick last spring. The 30-year-old performer said she recuperated in her Ontario home. She is releasing a new single this month to support the 2015 Special Olympics.
|
||||
Two FBI agents were injured in a crash and the suspect was shot before being captured. Kevone Charleston, 36, is suspected of involvement in 32 commercial robberies dating to November 2013. Charleston was shot and wounded by FBI agents and task force officers, but his injuries are not life threatening.
|
||||
NCAA says it has no legal responsibility to ensure academic integrity of courses offered to student-athletes. NCAA says it does not have direct, day-to-day, operational control over member institutions like UNC. UNC scandal involved thousands of athletes who, over 18 years, were funneled into classes that never met.
|
||||
The Large Hadron Collider -LRB- LHC -RRB- is ready for action following a two-year shutdown. The purpose of the lengthy project is to recreate the conditions that existed moments after the Big Bang. The LHC generates up to 600 million particles per second , with a beam circulating for 10 hours.
|
||||
Fighting has killed hundreds of people in Yemen in less than two weeks. The Red Cross has cried out for a humanitarian ceasefire to let aid in. The U.N. Security Council discussed the humanitarian situation at Russia 's behest. Moscow submitted a draft resolution calling for a halt to the airstrikes.
|
||||
20 states have some version of the religious liberty law, and the legal controversies have grown. A tea used by a Brazilian faith is to them like wine used by Catholics at communion. The church became alarmed and cited how the federal government allows an exception for American Indians to use another illegal drug.
|
||||
Somalia-based Al-Shabaab has been behind a string of recent attacks in Kenya. The group is predominantly driven by the same radical interpretation of the Koran as al-Qaeda and ISIS. It is not clear whether it will switch allegiances to ISIS.
|
||||
Cassandra C. is in remission after nearly six months of forced chemo treatments. A Connecticut juvenile court judge issued a written decision denying a motion to let the teen go home. The 17-year-old is in temporary custody of the state for the time being, her attorney says. Cassandra was diagnosed with Hodgkin lymphoma in September.
|
||||
Mark Ronson's "Uptown Funk!" is the longest-leading Billboard Hot 100 of the 2010s. It's also just the 10th single in the Hot 100 's entire history to spend at least 13 weeks at No. 1. `` Funk '' is just three weeks from potentially tying `` One Sweet Day '' for the record.
|
||||
Thai Smile unveils colorful new livery featuring Jake , Finn and the beloved Princess Bubblegum sprawled across an Airbus A320. The interior of the plane also has an Adventure Time theme with overhead bins, head rests and even air sickness bags covered in the faces of characters from the show. The inaugural Thai Smile Adventure Time flight takes place on April 4 , heading from Bangkok to Phuket.
|
||||
The temperature was recorded at Argentina 's Esperanza Base on the northern tip of the Antarctica Peninsula. The World Meteorological Organization is in the process of setting up an international ad-hoc committee. The committee will examine the equipment used to measure the temperature, whether it was in good working order.
|
||||
Kim Ki-Jong is charged with attempted murder. He is also charged with assaulting a foreign envoy and business obstruction. U.S. Ambassador Mark Lippert was stabbed March 5 during an event in Seoul. Police say Kim stabbed him because he opposed the joint South Korean-U.S.-North Korean military drills.
|
||||
Reaching a good, solid agreement with Iran is a worthy, desirable goal. But the process has unfolded under the destructive influence of political considerations, says Peter Bergen. Bergen: Obama has a huge political stake in these negotiations. He says the notion that Obama is not handling the Iranian threat effectively is contributing to a new war in Yemen.
|
||||
Deion Sanders calls out his son on Twitter. Deion Sanders Jr. is a wide receiver at Southern Methodist University. His Twitter timeline is a mix of biblical verses, motivational quotes and references to sports, cars, school and Balenciaga shoes. He also has gone on record with his love for hood doughnuts.
|
||||
Blue Bell ice cream has temporarily shut down one of its manufacturing plants. Public health officials warned consumers Friday not to eat any Blue Bell-branded products made at the plant. That includes 3-ounce servings of Blue Bell ice Cream from this plant that went to institutions. Listeria monocytogenes was recently found in a cup of ice cream recovered from the hospital.
|
||||
Former banker Rurik Jutting, 29, charged with two counts of murder in Hong Kong. Court hearing to determine whether there was enough evidence to proceed to trial adjourned until May. Police found bodies of two women at J Residence in Wan Chai last November. One woman was lying on the floor with cuts to her neck and buttocks. Another was stuffed inside a suitcase on the balcony.
|
||||
"Furious 7 '' is getting the widest release in Universal 's history. The final film featuring the late Paul Walker is opening around the globe this weekend. The movie enjoys massive awareness and interest, due to both the popularity of the street-racing series and Walker 's death.
|
||||
Manning is serving a 35-year prison sentence for leaking thousands of classified documents. She said she will be using a voice phone to dictate her tweets to communications firm Fitzgibbon Media. The former Army intelligence analyst was convicted of stealing and disseminating 750,000 pages of documents and videos.
|
||||
Can UVA, Phi Kappa Psi or any of the other fraternities on campus sue for defamation? The Virginia Supreme Court said in Jordan v. Kollman that the elements of libel are -LRB- 1 -RRB- publication of an actionable statement. A private person suing for defamation must establish that the defendant has published a false factual statement.
|
||||
Defense Minister Gen Nakatani told the Diet that his jets had never come across any UFOs from outer space. He was responding to a query from flamboyant former wrestler-turned-lawmaker Antonio Inoki. Inoki also claims to have seen a UFO with his own eyes, but admitted he did n't know personally if aliens existed.
|
||||
The FBI has confirmed that one of its most wanted terrorists, Marwan, was killed in the Philippines. Marwan was believed by the FBI to be a member of southeast Asian terror group Jemaah Islamiyah 's central command. The FBI had been offering a $ 5 million reward for information leading to Marwan's capture.
|
||||
Shibuya ward in Tokyo passes ordinance allowing same-sex couples some of the rights of married heterosexual couples. Proponents of marriage equality in socially conservative Japan say that the ward 's decision is a step in the right direction. A recent poll found that a slight majority at 52.4 % oppose gay marriage, but support amongst young adults in their 20s and 30s is as high as 70%.
|
||||
David Lynch has confirmed he will no longer direct the revival of Twin Peaks. The cult 1990s television show was set to return in 2016. Lynch broke the news about his departure in a series of tweets. He said he felt the network was not offering enough money to produce the show.
|
||||
A hot mic picked up Kentucky guard Andrew Harrison saying of Wisconsin's Frank Kaminsky. Harrison said his words were in jest and that he meant no disrespect to Kaminsky, who is white. Kaminsky said Sunday that he was over it and that nothing needs to be made out of it.
|
||||
The blast occurred at an oil storage facility Monday night after an oil leak. Five out of six people were injured by broken glass and have been sent to the hospital. The plant produces paraxylene -LRB- PX -RRB- which is used in the production of polyester films and fabrics.
|
||||
The EPA says there was a presence of methyl bromide in the unit where the family was staying. The use of the pesticide is restricted in the U.S. because of its acute toxicity. The EPA is working with local government agencies to investigate whether any environmental regulations or laws were violated.
|
||||
Rand Paul has tried to sell himself as a different type of Republican. He 's tried to brand himself as the GOP 's minority outreach candidate. A quick survey of Sen. Paul 's positions makes clear that he does not. The American people deserve better than Rand Paul.
|
||||
U.S. Department of Justice has named a new defendant in the war on drugs. The courier delivery service FedEx is charged in a 15-count indictment. Peter Bergen: Corporations can indeed be prosecuted like a person. He says FedEx argues that it is indeed a common carrier, performing the normal duties.
|
||||
For the 1960s, the end arrived with -- depending on your ideals and your tribe -- either the Rolling Stones ' Altamont fiasco in December 1969 or Richard Nixon 's 1972 re-election. The end of a TV series brings with it some risk. Expect a number of longtime characters -- Ken Cosgrove, Harry Crane, Joan Harris -- to look for an exit.
|
||||
Comedian Chris Rock documented three traffic stops in seven weeks. Many African-Americans have long bemoaned being pulled over for no apparent reason. Blacks are about 30 % more likely to be pulled over by police than whites. Actor Isaiah Washington urged Rock to '' #Adapt '' to avoid racial profiling.
|
||||
Boston native Mark Wahlberg will star in a film about the Boston Marathon bombing and the manhunt that followed. The film is being produced by CBS Films and will feature material researched and shot by CBS News program 60 Minutes. Fox announced in November that it will be making a film called '' Boston Strong '' about the event.
|
||||
Model Manuela Arbelaez accidentally reveals the correct answer to a guessing game for a new Hyundai Sonata. Host Drew Carey could n't stop laughing. Arbelarez was mortified , attempting to hide behind the display. But everything turned out OK, she tweeted later.
|
||||
Kenyans use social media to share the victims' stories, hopes and dreams. The hashtag # 147notjustanumber refers to the number of people killed at Garissa University College. Kenyan authorities have not released a list of the victims. The attack was the nation 's deadliest since the bombing of the U.S. Embassy in 1998.
|
||||
Investigators are not expected to return to the crash site, a French national police official says. The plane crashed March 24 in rugged terrain of the Alps about 6 miles from the town of Seyne-les-Alpes. The flight data recorder was found Thursday by a member of the recovery team.
|
||||
Carlos Colina, 32, arraigned on charges of assault and battery causing serious bodily injury and improper disposal of a body. Police were notified Saturday morning about a suspicious item along a walkway in Cambridge. Officers opened a duffel bag and found human remains. Surveillance video led them to an apartment building where more body parts were discovered.
|
||||
American Jennifer Stewart says Etihad Airways lost her most important baggage. Her 2-year-old pet cat, Felix, went missing on a flight from Abu Dhabi to New York. Stewart believes the cat 's plastic carrier was badly damaged during the flight or the transfer from the airplane to the pickup area.
|
||||
Executive producer Brian Grazer said the show will return for a fifth season of 17 episodes. The fourth season was streamed exclusively on Netflix in 2013 after Fox canceled the show several years before. It was not yet known if the full cast , including Jason Bateman , Michael Cera and Will Arnett will return.
|
||||
Easter is unique on the Christian calendar, a major point in the cycle of the religious year. Easter Triduum refers to the three days of Easter that begin with Good Friday, proceed through Holy Saturday, and conclude with Easter Sunday. Easter embraces the great mystery of resurrection , with its promise of transformation -- a shift from one form to another.
|
||||
Money gets 37 mentions in the New Testament, while gold gets 38 citations. The Jesus community had a common purse because they needed money to survive. Jesus and his disciples walked, wore what they had, slept outside or in stayed in friends ' homes. Rabbi Joshua Garroway: It was possible Jesus and followers received donations from supporters.
|
||||
The nation 's top stories will be unfolding Tuesday in courthouses and political arenas across the country. In Louisville, Kentucky, Sen. Rand Paul made the not-so-surprising announcement that he will run for president. In Chicago, voters will head to the polls in a very surprising runoff between Mayor Rahm Emanuel and challenger Jesus Garcia. In Ferguson, Missouri, the shadow of Michael Brown and the protests over his shooting by Officer Darren Wilson will loom large over the city 's elections.
|
||||
Ted Cruz hit the trail in Iowa for the first time as a presidential candidate last week. Cruz drew large crowds during his two-day swing across the state. He 's counting on Iowa, known for its vocal and active evangelical base, to propel him forward. Cruz was one of the loudest defenders of the religious freedom law in Indiana.
|
||||
Nina Dobrev announced she will be leaving the CW show at the end of this season. Many chastised the show 's producers for allowing the show to go on to a seventh season. Dobrev seemed to anticipate the pain , urging fans to hold on through the season finale.
|
||||
Zhou Yongkang is the highest-ranking Chinese Communist Party official ever to face corruption charges. Zhou was one of nine men who effectively ruled the country of more than 1.3 billion people. Zhou controlled police forces, spy agencies, court systems as well as prosecution offices across China. He was expelled from the Communist Party and arrested last December.
|
||||
Education minister Smriti Irani was visiting a FabIndia outlet in the tourist resort state of Goa. She discovered a surveillance camera pointed at the changing room, police said. Four employees of the store have been arrested, but its manager is still at large. The arrested staff have been charged with voyeurism and breach of privacy.
|
||||
The group included four children -- the oldest being 10 or 11 , a Turkish official says. The nine were arrested at the Turkey-Syria border, the Turkish military says. It did n't say why the group allegedly was trying to get into Syria. The British Foreign Office says it is aware of reports of the arrests.
|
||||
Kayahan was one of Turkey 's best-loved singers and songwriters. He was first diagnosed with cancer in 1990, the year he competed in the Eurovision Song Contest. The cancer returned in 2005 and then again in 2014. He died Friday in a hospital in Istanbul, five days after his 66th birthday.
|
||||
A nuclear submarine being repaired at a Russian shipyard has caught on fire. The submarine is in a dry dock and there is no ammunition on board. The fire presents no threat to people and the shipyard, a spokesman says. The sub had been undergoing repairs since November 2013.
|
||||
Mexican state oil company Pemex said 45 workers were injured in the blaze. Two of them are in serious condition, the company said. Authorities evacuated about 300 people from the Abkatun Permanente platform. At least 10 boats worked to battle the blaze for hours off the coast of Mexico.
|
||||
TV5Monde was gradually regaining control of its channels and social media outlets. The outage began around 8:45 p.m. Paris time Wednesday. The network said it was hacked by an Islamist group. There was no immediate claim of responsibility by ISIS or any other group.
|
||||
A group of armed assailants stormed into the attorney general 's office in Balkh province. Two police officers and a security guard of the provincial attorney general's office were among the dead. Most staff members and civilians have been rescued. An exchange of fire between Afghan security forces and the assailants is ongoing.
|
||||
Jurors found Dzhokhar Tsarnaev guilty of all 30 counts he faced in the Boston Marathon bombing trial. Seventeen of the 30 counts were capital charges, meaning he is eligible for the death penalty. The trial will next move into a penalty phase, where the jury will hear testimony and arguments from both sides. Jurors will be asked to weigh aggravating factors against mitigating factors.
|
||||
The fire broke out at the General Electric Appliance Park in Louisville, Kentucky. Video showed both smoke and bright orange flames. There were no reports of anyone injured or trapped. The park is large, such that 34 football fields could fit in one of its warehouses in the facility.
|
||||
A U.S. official says Washington is not putting a timeframe on a possible invasion of Mosul. Mosul has long been the big prize in the Iraqi government 's fight to defeat ISIS. It has also long been a source of embarrassment, considering how it fell last June.
|
||||
Most of the penalty amounts to forced spending on improving pipeline safety. On September 9, 2010, a section of PG&E pipeline exploded in San Bruno, killing eight people. The company says it has paid more than $ 500 million in claims to the victims and victims ' families.
|
||||
One person is killed in Fairdale, Illinois. The tornado cuts a 22-mile path through Ogle County. Hail stones the size of tennis balls plummet down on Ashton. The National Weather Service warns people to be on alert for severe weather on Friday. The storm took away a local favorite restaurant.
|
||||
Officer Michael Slager has been fired and charged with murder in the death of 50-year-old Walter Scott. A bystander 's cell phone video shows the five-year police veteran shooting at Scott eight times as Scott runs away. The officer initially said that he used a Taser on Scott , who, Slager said, tried to take the weapon.
|
||||
Lucknow police to use pepper-spraying drones to control unruly crowds. The miniature aircraft will be fitted with a camera and pepper spray. Each drone costs between $ 9,560 and $ 19,300. Views on the new measure are mixed, with some concerned about the suppression of freedom of speech.
|
||||
A female employee accused Xavier Morales , a supervisor within the agency , of assault after he made sexual advances at her. This is just the latest chapter for an organization embroiled in scandal over the past several months. Last month, two top-ranking officials were suspended following an incident at a White House command post.
|
||||
Richard Dysart played cranky senior partner Leland McKenzie in NBC's "L.A. Law" Dysart was nominated for Emmy for outstanding supporting actor in a drama series for four straight years. He was one of the few actors to appear in every episode of the long-running series. Dysart also played Harry Truman in the CBS telefilm "Day One"
|
||||
Vijay Chokalingam applied to medical school claiming to be African-American. He received only one admission offer, to St. Louis University 's School of Medicine. He claims African-Americans garner special privileges that are unavailable to whites or Asians. The statute of limitations on his act of fraud has expired, he says.
|
||||
Don McLean 's pop masterpiece "American Pie" sold for $1.2 million. Song is a hybrid of modern poetry and folk ballad, beer-hall chant and high-art rock. Song replaced Bob Dylan 's "The Times They Are A Changin '' as Peoples Almanac of the new decade.
|
||||
Anthony Ray Hinton was convicted of murder in the 1985 deaths of two Birmingham-area fast-food managers. A new trial was ordered in 2014 after firearms experts testified 12 years earlier that the revolver could not be matched to evidence in either case. The state then declined to re-prosecute the case. Hinton, 58, was 29 at the time of the killings.
|
||||
April 8 was also "Rex Manning Day" on Twitter. April 14 was the date of Titanic's doomed love affair. October 3 is Mean Girls Day. April 25 is Miss Congeniality Day. October 21 is Back to the Future Part II's date in the 1989 film.
|
||||
Mary Kay Letourneau Fualaau was a married 34-year-old teacher and mother of four in 1996 when she began an affair with her 13-year old student. She gave birth to her young lover 's child and went on to serve more than seven years in prison on charges related to their sexual relationship. The pair wed soon after she was released from prison in 2005 and are now the parents of two teen girls.
|
||||
Twisted Sister says its 2016 tour will be its last. Next year marks the band 's 40th anniversary. Band will play with a new drummer, Mike Portnoy of Adrenaline Mob. The band will also perform two shows in Pero 's honor : one at Las Vegas ' Hard Rock Hotel and Casino and the other in Sayreville, New Jersey.
|
||||
The Weinstein Co. decided to sell the film directly to Lifetime rather than book it into U.S. theaters. The critically-panned film opened last year 's Cannes Film Festival. The film focuses on a period in the early '60s when Monaco was involved in a stand-off over taxes with France.
|
||||
The lyrics to the famed Don McLean song sold for $ 1.2 million Tuesday morning at an auction held by Christie 's. McLean has said that the opening lines were inspired by the death of Buddy Holly. The song catapulted the former folk singer to headliner status. The record for a popular music manuscript is held by Bob Dylan's "Like a Rolling Stone"
|
||||
Kanye West has settled a lawsuit with a paparazzi photographer he assaulted. The photographer, Daniel Ramos, had filed the civil suit against West after the hip-hop star attacked him. West pleaded no contest last year to a misdemeanor count of battery over the scuffle. A judge sentenced him to two years ' probation , as well as anger management sessions.
|
||||
Scientologist John Travolta is not a fan of HBO's new documentary. The actor is one of the Church of Scientology 's most high-profile members. The HBO documentary is critical of the organization, which has close ties to the showbiz industry. He credited the church with helping him to survive the death of his teen son, Jett.
|
||||
Blues legend B.B. King was hospitalized for dehydration. dehydration was caused by his Type II diabetes, his daughter says. King, 89, has 30 Grammy nominations and was inducted into the Rock and Roll Hall of Fame in 1987. Last year, the bluesman suffered from dehydration and exhaustion after a show.
|
||||
Dzhokhar Tsarnaev is found guilty on all 30 counts he faced for the Boston Marathon bombings. He and his brother planted bombs at the marathon, setting off deadly explosions. Hundreds of people were wounded in the bombings and their aftermath. The verdict brings a mix of emotions, from triumphant vows to move forward to expressions of gratitude.
|
||||
Chinese TV host known for impromptu satire caught on camera cursing the late Chairman Mao Zedong. Bi Fujian was filmed at a dinner party singing a revolutionary song that eulogizes the Communist Party 's early years when he started going off script. The 75-second video clip was uploaded on Monday and has since been removed from video-sharing sites inside China. Bi later apologized, saying his personal speech had led to "grave social consequences"
|
||||
Tornado sirens blare in Kansas as several storms bring reports of twisters. A tornado may have touched down in the small town of Potosi , Missouri. More storms are expected in the Midwest, Mississippi River Valley, Tennessee River Valley and near the southern Great Lakes.
|
||||
Katie, a giraffe at the Dallas Zoo, gave birth to a not-so-little baby early Friday evening. There was no immediate word on the newborn 's gender or condition. The baby joins a sister, 4-year-old calf Jamie. Katie is one of the only giraffes who can stick her long tongue out on cue.
|
||||
The fifth season of HBO's Game of Thrones premieres Sunday. It will be the most high-profile premiere yet, airing simultaneously in 170 countries for the first time. The Stark daughters, Arya and Sansa, will be characters to watch this season. The show currently set to end after seven years.
|
||||
Heads of state from 35 countries in the Western Hemisphere have met every three years to discuss economic, social or political issues. Cuba has historically been the wrench in the diplomatic machinery. Some Latin American leaders threatened not to attend the Summit of the Americas if the United States and Canada did n't agree to invite President Raul Castro.
|
||||
President Obama is giving up enormous leverage in his nuclear deal with Iran, says Peter Bergen. The deal would hand Tehran billions of previously sanctioned funds, he says. Bergen: The deal lacks tough safeguards to stop Iran from cheating. The best predictor of Iran's future behavior is its past behavior, he adds.
|
||||
Martin O'Malley and Jim Webb are both toying with a presidential run. Both shared a stage at the Polk County Democrats Awards Dinner in Des Moines, Iowa. Webb is a decorated Vietnam War veteran and former senator from Virginia. In a March CNN/ORC poll -LRB- PDF -RRB- of national Democrats, only 1 % said O'Martin and Webb were their top choice.
|
||||
Hundreds of mourners gather at a South Carolina church for the funeral of Walter Scott. The father of four was fatally shot in the back by a police officer a week ago. Officer Michael Slager was swiftly charged with murder and faces life in prison or the death penalty. Scott 's family was missing for the private burial.
|
||||
The shootings are connected, authorities say. They began with what authorities believe was a domestic kidnapping incident. The suspect 's vehicle was spotted outside the Census Bureau in Suitland, Maryland. A guard apparently approached the vehicle and saw two people arguing . That guard was then shot at least once in the upper body.
|
||||
Iranian military observation aircraft flew within 50 yards of an armed U.S. Navy helicopter over the Persian Gulf this month. The incident sparked concern that top Iranian commanders might not be in full control of local forces. The Navy MH-60R armed helicopter was flying from the deck of the USS Carl Vinson.
|
||||
Liana Barrientos has been married 10 times in New York. Prosecutors say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the Bronx on Friday. She was arrested and charged for allegedly sneaking into the New York subway through an emergency exit.
|
||||
#UporDown is a trending question on social media thanks to a photo of a cat coming down some stairs. The picture was apparently uploaded on Imgur a few days ago. Some people are noting the apparent motion of the cat . Others are commenting about the construction of the stairs.
|
||||
Pope Francis has reawakened the faith of a lapsed Catholic, says CNN's Richard Quest. Quest met one of the Pope 's newly appointed cardinals, Cardinal Gerald Lacroix, in Quebec City. The 57-year-old compared the Pope's approach to Jesus, who preached about the love of God.
|
||||
Mullah Mohammed Omar is credited with founding the Taliban in the early 1990s. Omar all but disappeared after a U.S.-led bombing campaign routed the Taliban from Kabul. The Taliban has released written statements purportedly made by the leader-in-hiding. But years without any video or audio recordings of the fugitive have led to growing speculation that Omar may have died.
|
||||
Amnesty International report is calling for authorities to address the number of attacks on women 's rights activists in Afghanistan. The report examines the persecution of activists by the Taliban and tribal warlords. The brutal murder of Farkhunda, a young woman whose body was burnt and callously chucked into a river in Kabul, shocked the world.
|
||||
Nelly is charged with felony possession of drugs, simple possession of marijuana and possession of drug paraphernalia. A state trooper stopped the bus because it was not displaying U.S. Department of Transportation and International Fuel Tax Association stickers. The trooper was about to conduct an inspection of the bus when he noticed an odor of marijuana emitting from the vehicle. Two troopers then searched the bus, finding five colored crystal-type rocks and methamphetamine.
|
||||
Collection of the first six Star Wars '' movies will also include many special features. Some of the features give fans a rare glimpse behind the scenes of the saga. Sound designer Ben Burtt explains which animals were used to capture the alien sounds made by the Geonosians. Take a look at the video above to find out.
|
||||
Hillary Clinton is finally announcing her candidacy for the 2016 presidential election. Julian Zelizer says there is ample reason to be excited about Clinton 's run for the presidency. Zelizer: If Clinton puts together an effective campaign she could be unbeatable in the Democratic primaries as well as in the general election. He says Clinton will have to contend with doubts about her authenticity.
|
||||
Martin O'Malley told reporters in Iowa on Friday that inevitability -- a term bandied about regarding Democratic presidential frontrunner Hillary Clinton -- is not unbreakable. Clinton was considered inevitable to win the nomination in 2008 but ended up losing to Barack Obama. The former governor capped off his two-day trip to the first-in-the-nation caucus state with a speech to the Polk County Democrats in Des Moines.
|
||||
Ramalinga Raju is the former chairman of software services exporter Satyam Computers Services. Raju admitted inflating profits with fictitious assets and nonexistent cash. He and nine others were convicted of cheating, criminal conspiracy, breach of public trust. The case has been compared to the 2001 Enron Corp. scandal.
|
||||
Sweden is said to be the most generous nation on Earth for parental leave. Fathers have to share that leave with mothers to promote both parents to raise their children. Only 12 % of Swedish couples equally share the 480 days of leave. Photographer Johan Bavman is looking for a total of 60 fathers to photograph to culminate in an exhibition and book.
|
||||
Craig Hicks, 46, is charged in the deaths of three Muslim college students in Chapel Hill, North Carolina. Superior Court Judge Orlando Hudson Jr. ruled that Hicks ' case is death penalty qualified. The victims ' family members have called on authorities to investigate the slayings as a hate crime. Police said an ongoing neighbor dispute over parking might have been a factor.
|
||||
Colin Farrell, Vince Vaughn, Rachel McAdams and Taylor Kitsch star in the new season. The new season premieres June 21. The first season starred Matthew McConaughey and Woody Harrelson as two Louisiana State Police detectives investigating the death of a young woman.
|
|
@ -17,11 +17,9 @@ from fastseq.logging import get_logger
|
|||
from fastseq import config
|
||||
from fastseq.utils.file_utils import decompress_file, make_dirs, wget
|
||||
from fastseq.utils.test_utils import (BART_MODEL_URLS, CACHED_BART_MODEL_DIR,
|
||||
CACHED_BART_MODEL_PATHS,
|
||||
CACHED_BART_MODEL_PATHS, CNNDM_RAW_URL, CACHED_CNNDM_RAW_DATA_DIR,
|
||||
fastseq_test_main, TestCaseBase)
|
||||
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class FairseqBeamSearchOptimizerTest(TestCaseBase):
|
||||
|
@ -49,12 +47,22 @@ class FairseqBeamSearchOptimizerTest(TestCaseBase):
|
|||
CACHED_BART_MODEL_PATHS['bart.large.cnn'],
|
||||
checkpoint_file='model.pt')
|
||||
|
||||
self.source_path = 'tests/optimizer/fairseq/data/cnndm_128.txt'
|
||||
make_dirs(CACHED_CNNDM_RAW_DATA_DIR, exist_ok=True)
|
||||
self.source_path = os.path.join(CACHED_CNNDM_RAW_DATA_DIR, 'cnndm_128.txt')
|
||||
if not os.path.exists(self.source_path):
|
||||
with open(self.source_path, 'xb') as source_file:
|
||||
wget(os.path.join(CNNDM_RAW_URL, 'cnndm_128.txt'), source_file)
|
||||
source_file.close()
|
||||
|
||||
self.target_path = os.path.join(CACHED_CNNDM_RAW_DATA_DIR, 'expected_output.hypo')
|
||||
if not os.path.exists(self.target_path):
|
||||
with open(self.target_path, 'xb') as target_file:
|
||||
wget(os.path.join(CNNDM_RAW_URL, 'expected_output.hypo'), target_file)
|
||||
target_file.close()
|
||||
|
||||
# read the expected output.
|
||||
self.expected_output_path = 'tests/optimizer/fairseq/data/expected_output.hypo' # pylint: disable=line-too-long
|
||||
self.expected_outputs = []
|
||||
with open(self.expected_output_path, 'rt',
|
||||
with open(self.target_path, 'rt',
|
||||
encoding="utf-8") as expected_output_file:
|
||||
for line in expected_output_file:
|
||||
self.expected_outputs.append(line.strip())
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
# Licensed under the MIT License.
|
||||
""" script for importing fairseq tests """
|
||||
|
||||
from fastseq.config import USE_EL_ATTN
|
||||
import glob
|
||||
import io
|
||||
import logging
|
||||
|
@ -36,8 +37,7 @@ class FairseqUnitTests(parameterized.TestCase):
|
|||
if os.path.isdir(FAIRSEQ_PATH):
|
||||
shutil.rmtree(FAIRSEQ_PATH)
|
||||
Repo.clone_from(FAIRSEQ_GIT_URL, FAIRSEQ_PATH, branch=version)
|
||||
pipmain(['install', 'git+https://github.com/pytorch/fairseq.git@' +
|
||||
version])
|
||||
pipmain(['install', '--editable', FAIRSEQ_PATH])
|
||||
original_pythonpath = os.environ[
|
||||
'PYTHONPATH'] if 'PYTHONPATH' in os.environ else ''
|
||||
os.environ['PYTHONPATH'] = FAIRSEQ_PATH + ':' + original_pythonpath
|
||||
|
@ -57,12 +57,14 @@ class FairseqUnitTests(parameterized.TestCase):
|
|||
]
|
||||
return suites
|
||||
|
||||
@parameterized.named_parameters({
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
'testcase_name': 'Normal',
|
||||
'without_fastseq_opt': False,
|
||||
'fairseq_version': 'v0.9.0',
|
||||
'fairseq_version': 'v0.10.2',
|
||||
'blocked_tests': [
|
||||
'test_binaries.py', 'test_bmuf.py', 'test_reproducibility.py']
|
||||
'test_binaries.py', 'test_bmuf.py', 'test_reproducibility.py',
|
||||
'test_sequence_generator.py', 'test_backtranslation_dataset.py']
|
||||
})
|
||||
def test_suites(self, without_fastseq_opt, fairseq_version, blocked_tests):
|
||||
""""run test suites"""
|
||||
|
|
|
@ -13,5 +13,8 @@ pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pyto
|
|||
rm -rf build/
|
||||
rm ngram_repeat_block_cuda*.so
|
||||
pip install --editable .
|
||||
python tests/run_fairseq_tests.py
|
||||
echo "******* Run Fairseq tests with Beam Search Optimization *******"
|
||||
USE_EL_ATTN=0 python tests/run_fairseq_tests.py
|
||||
echo "******* Run Fairseq tests with EL Attn Optimization *******"
|
||||
USE_EL_ATTN=1 python tests/run_fairseq_tests.py
|
||||
deactivate
|
||||
|
|
Загрузка…
Ссылка в новой задаче