Add m2m100 (#10236)
* m2m_100
* no layernorm_embedding
* sinusoidal positional embeddings
* update pos embeddings
* add default config values
* tokenizer
* add conversion script
* fix config
* fix pos embed
* remove _float_tensor
* update tokenizer
* update lang codes
* handle lang codes
* fix pos embeds
* fix spm key
* put embedding weights on device
* remove qa and seq classification heads
* fix convert script
* lang codes pn one line
* fix embeds
* fix tokenizer
* fix tokenizer
* add fast tokenizer
* style
* M2M100MT => M2M100
* fix copyright, style
* tokenizer converter
* vocab file
* remove fast tokenizer
* fix embeds
* fix tokenizer
* fix tokenizer
* fix tests
* add tokenizer tests
* add integration test
* quality
* fix model name
* fix test
* doc
* doc
* fix doc
* add copied from statements
* fix tokenizer tests
* apply review suggestions
* fix urls
* fix shift_tokens_right
* apply review suggestions
* fix
* fix doc
* add lang code to id
* remove unused function
* update checkpoint names
* fix copy
* fix tokenizer
* fix checkpoint names
* fix merge issue
* style
This commit is contained in:
Parent: fd01104435
Commit: f6e74a63ca
|
@ -217,6 +217,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
|||
1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[Longformer](https://huggingface.co/transformers/model_doc/longformer.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[M2M100](https://huggingface.co/transformers/model_doc/m2m_100.html)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
|
||||
1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
|
||||
1. **[MBart](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
||||
1. **[MBart-50](https://huggingface.co/transformers/model_doc/mbart.html)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
|
||||
|
|
|
@ -161,57 +161,61 @@ and conversion utilities for the following models:
|
|||
26. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
|
||||
Encoder Representations from Transformers for Open-Domain Question Answering <https://arxiv.org/abs/1908.07490>`__
|
||||
by Hao Tan and Mohit Bansal.
|
||||
27. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
|
||||
27. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
|
||||
Machine Translation <https://arxiv.org/abs/2010.11125>`__ by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi
|
||||
Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman
|
||||
Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
|
||||
28. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
|
||||
Jörg Tiedemann. The `Marian Framework <https://marian-nmt.github.io/>`__ is being developed by the Microsoft
|
||||
Translator Team.
|
||||
28. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
||||
29. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
||||
Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
|
||||
Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
||||
29. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
|
||||
30. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
|
||||
Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li,
|
||||
Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
|
||||
30. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
||||
31. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
||||
Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
|
||||
Jianfeng Lu, Tie-Yan Liu.
|
||||
31. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
||||
32. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
||||
text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
|
||||
Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
32. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
||||
33. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
||||
Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__ by Jingqing Zhang, Yao Zhao,
|
||||
Mohammad Saleh and Peter J. Liu.
|
||||
33. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
||||
34. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
||||
Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
|
||||
Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
34. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
||||
35. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
||||
Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
||||
35. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||
36. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
|
||||
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
|
||||
36. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
37. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
|
||||
Krishna, and Kurt W. Keutzer.
|
||||
37. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
38. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
38. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
39. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||
Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
39. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
40. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
40. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
41. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
|
||||
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
|
||||
Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||
41. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
42. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
42. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
43. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
43. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
44. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
44. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
45. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
|
||||
|
@ -276,6 +280,8 @@ TensorFlow and/or Flax.
|
|||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Longformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| M2M100 | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| MPNet | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Marian | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
|
@ -416,6 +422,7 @@ TensorFlow and/or Flax.
|
|||
model_doc/longformer
|
||||
model_doc/lxmert
|
||||
model_doc/marian
|
||||
model_doc/m2m_100
|
||||
model_doc/mbart
|
||||
model_doc/mobilebert
|
||||
model_doc/mpnet
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
M2M100
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The M2M100 model was proposed in `Beyond English-Centric Multilingual Machine Translation
|
||||
<https://arxiv.org/abs/2010.11125>`__ by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky,
|
||||
Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy
|
||||
Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Existing work in translation demonstrated the potential of massively multilingual machine translation by training a
|
||||
single model able to translate between any pair of languages. However, much of this work is English-Centric by training
|
||||
only on data which was translated from or to English. While this is supported by large sources of training data, it
|
||||
does not reflect translation needs worldwide. In this work, we create a true Many-to-Many multilingual translation
|
||||
model that can translate directly between any pair of 100 languages. We build and open source a training dataset that
|
||||
covers thousands of language directions with supervised data, created through large-scale mining. Then, we explore how
|
||||
to effectively increase model capacity through a combination of dense scaling and language-specific sparse parameters
|
||||
to create high quality models. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly
|
||||
translating between non-English directions while performing competitively to the best single systems of WMT. We
|
||||
open-source our scripts so that others may reproduce the data, evaluation, and final M2M-100 model.*
|
||||
|
||||
|
||||
Training and Generation
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
M2M100 is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation tasks. As the model is
|
||||
multilingual, it expects the sequences in a certain format: a special language id token is used as a prefix in both the
|
||||
source and target text. The text format is :obj:`[lang_code] X [eos]`, where :obj:`lang_code` is the source language
|
||||
id for the source text and the target language id for the target text, and :obj:`X` is the source or target text.
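For example (an illustrative sketch; the exact subword pieces depend on the SentencePiece model, so only the
structure of the ids is shown), encoding an English source sentence places the language token first and the
:obj:`eos` token last:

.. code-block::

from transformers import M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained('facebook/m2m100_418M', src_lang="en")
input_ids = tokenizer("Life is like a box of chocolates.").input_ids
# input_ids == [id of "__en__", subword ids of the sentence..., id of "</s>"]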
|
||||
|
||||
- Supervised Training
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Tokenizer
|
||||
|
||||
model = M2M100ForConditionalGeneration.from_pretrained('facebook/m2m100_418M')
|
||||
tokenizer = M2M100Tokenizer.from_pretrained('facebook/m2m100_418M', src_lang="en", tgt_lang="fr")
|
||||
|
||||
src_text = "Life is like a box of chocolates."
|
||||
tgt_lang = "La vie est comme une boîte de chocolat."
|
||||
|
||||
model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||
|
||||
loss = model(**model_inputs, labels=labels).loss  # forward pass
|
||||
|
||||
|
||||
- Generation
|
||||
|
||||
M2M100 uses the :obj:`eos_token_id` as the :obj:`decoder_start_token_id` for generation with the target language id
|
||||
being forced as the first generated token. To force the target language id as the first generated token, pass the
|
||||
:obj:`forced_bos_token_id` parameter to the :obj:`generate` method. The following example shows how to translate from
|
||||
Hindi to French and from Chinese to English using the `facebook/m2m100_418M` checkpoint.
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
||||
|
||||
>>> hi_text = "जीवन एक चॉकलेट बॉक्स की तरह है।"
|
||||
>>> chinese_text = "生活就像一盒巧克力。"
|
||||
|
||||
>>> model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
|
||||
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
|
||||
|
||||
>>> # translate Hindi to French
|
||||
>>> tokenizer.src_lang = "hi"
|
||||
>>> encoded_hi = tokenizer(hi_text, return_tensors="pt")
|
||||
>>> generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.get_lang_id("fr"))
|
||||
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
||||
"La vie est comme une boîte de chocolat."
|
||||
|
||||
>>> # translate Chinese to English
|
||||
>>> tokenizer.src_lang = "ar_AR"
|
||||
>>> encoded_zh = tokenizer(chinese_text, return_tensors="pt")
|
||||
>>> generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en"))
|
||||
>>> tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
||||
"Life is like a box of chocolate."
|
||||
|
||||
|
||||
M2M100Config
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.M2M100Config
|
||||
:members:
|
||||
|
||||
|
||||
M2M100Tokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.M2M100Tokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
M2M100Model
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.M2M100Model
|
||||
:members: forward
|
||||
|
||||
|
||||
M2M100ForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.M2M100ForConditionalGeneration
|
||||
:members: forward
|
||||
|
||||
|
|
@ -365,6 +365,12 @@ For the full list, refer to `https://huggingface.co/models <https://huggingface.
|
|||
| | ``reformer-crime-and-punishment`` | | 6-layer, 256-hidden, 2-heads, 3M parameters |
|
||||
| | | | Trained on English text: Crime and Punishment novel by Fyodor Dostoyevsky. |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| M2M100 | ``facebook/m2m100_418M`` | | 24-layer, 1024-hidden, 16-heads, 418M parameters |
|
||||
| | | | multilingual machine translation model for 100 languages |
|
||||
| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| | ``facebook/m2m100_1.2B`` | | 48-layer, 1024-hidden, 16-heads, 1.2B parameters |
|
||||
| | | | multilingual machine translation model for 100 languages |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. |
|
||||
| | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
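Both M2M100 checkpoints above expose the same API; a minimal loading sketch (this assumes the
``facebook/m2m100_1.2B`` repository listed in the table also hosts the tokenizer files):

.. code-block::

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B", src_lang="en", tgt_lang="fr")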
|
||||
|
|
|
@ -134,6 +134,7 @@ _import_structure = {
|
|||
"Wav2Vec2FeatureExtractor",
|
||||
"Wav2Vec2Processor",
|
||||
],
|
||||
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config", "M2M100Tokenizer"],
|
||||
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
|
||||
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
|
||||
"models.auto": [
|
||||
|
@ -385,6 +386,14 @@ if is_torch_available():
|
|||
"Wav2Vec2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.m2m_100"].extend(
|
||||
[
|
||||
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"M2M100ForConditionalGeneration",
|
||||
"M2M100Model",
|
||||
]
|
||||
)
|
||||
|
||||
_import_structure["models.convbert"].extend(
|
||||
[
|
||||
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
|
@ -1348,6 +1357,7 @@ if TYPE_CHECKING:
|
|||
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
|
||||
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
|
||||
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
|
||||
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config, M2M100Tokenizer
|
||||
from .models.marian import MarianConfig
|
||||
from .models.mbart import MBartConfig
|
||||
from .models.mmbt import MMBTConfig
|
||||
|
@ -1768,6 +1778,7 @@ if TYPE_CHECKING:
|
|||
LxmertVisualFeatureEncoder,
|
||||
LxmertXLayer,
|
||||
)
|
||||
from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model
|
||||
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
|
||||
from .models.mbart import (
|
||||
MBartForCausalLM,
|
||||
|
|
|
@ -45,6 +45,7 @@ from . import (
|
|||
led,
|
||||
longformer,
|
||||
lxmert,
|
||||
m2m_100,
|
||||
marian,
|
||||
mbart,
|
||||
mmbt,
|
||||
|
|
|
@ -45,6 +45,7 @@ from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE
|
|||
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
|
||||
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
|
||||
from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig
|
||||
from ..m2m_100.configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
|
||||
from ..marian.configuration_marian import MarianConfig
|
||||
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
|
||||
from ..mobilebert.configuration_mobilebert import MobileBertConfig
|
||||
|
@ -76,6 +77,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
|||
for pretrained_map in [
|
||||
# Add archive maps here
|
||||
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
|
@ -121,6 +123,7 @@ CONFIG_MAPPING = OrderedDict(
|
|||
[
|
||||
# Add configs here
|
||||
("wav2vec2", Wav2Vec2Config),
|
||||
("m2m_100", M2M100Config),
|
||||
("convbert", ConvBertConfig),
|
||||
("led", LEDConfig),
|
||||
("blenderbot-small", BlenderbotSmallConfig),
|
||||
|
@ -172,6 +175,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||
[
|
||||
# Add full (and cased) model names here
|
||||
("wav2vec2", "Wav2Vec2"),
|
||||
("m2m_100", "M2M100"),
|
||||
("convbert", "ConvBERT"),
|
||||
("led", "LED"),
|
||||
("blenderbot-small", "BlenderbotSmall"),
|
||||
|
|
|
@ -158,6 +158,7 @@ from ..longformer.modeling_longformer import (
|
|||
LongformerModel,
|
||||
)
|
||||
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
|
||||
from ..m2m_100.modeling_m2m_100 import M2M100ForConditionalGeneration, M2M100Model
|
||||
from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel
|
||||
from ..mbart.modeling_mbart import (
|
||||
MBartForCausalLM,
|
||||
|
@ -283,6 +284,7 @@ from .configuration_auto import (
|
|||
LEDConfig,
|
||||
LongformerConfig,
|
||||
LxmertConfig,
|
||||
M2M100Config,
|
||||
MarianConfig,
|
||||
MBartConfig,
|
||||
MobileBertConfig,
|
||||
|
@ -314,6 +316,7 @@ MODEL_MAPPING = OrderedDict(
|
|||
[
|
||||
# Base model mapping
|
||||
(Wav2Vec2Config, Wav2Vec2Model),
|
||||
(M2M100Config, M2M100Model),
|
||||
(ConvBertConfig, ConvBertModel),
|
||||
(LEDConfig, LEDModel),
|
||||
(BlenderbotSmallConfig, BlenderbotSmallModel),
|
||||
|
@ -397,6 +400,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
|||
[
|
||||
# Model with LM heads mapping
|
||||
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
|
||||
(M2M100Config, M2M100ForConditionalGeneration),
|
||||
(ConvBertConfig, ConvBertForMaskedLM),
|
||||
(LEDConfig, LEDForConditionalGeneration),
|
||||
(BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
|
||||
|
@ -495,6 +499,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
|||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[
|
||||
# Model for Seq2Seq Causal LM mapping
|
||||
(M2M100Config, M2M100ForConditionalGeneration),
|
||||
(LEDConfig, LEDForConditionalGeneration),
|
||||
(BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
|
||||
(MT5Config, MT5ForConditionalGeneration),
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
|
||||
"tokenization_m2m_100": ["M2M100Tokenizer"],
|
||||
}
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_m2m_100"] = [
|
||||
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"M2M100ForConditionalGeneration",
|
||||
"M2M100Model",
|
||||
"M2M100PreTrainedModel",
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
|
||||
from .tokenization_m2m_100 import M2M100Tokenizer
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_m2m_100 import (
|
||||
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
M2M100ForConditionalGeneration,
|
||||
M2M100Model,
|
||||
M2M100PreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
class _LazyModule(_BaseLazyModule):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
__file__ = globals()["__file__"]
|
||||
__path__ = [os.path.dirname(__file__)]
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
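# Illustrative note (not part of the committed file): with the lazy module above in place, the config
# and tokenizer can be imported without torch installed, while the entries registered under
# _import_structure["modeling_m2m_100"] are only resolved (and require torch) when accessed, e.g.
#
#   from transformers.models.m2m_100 import M2M100Config, M2M100Tokenizer
#   from transformers.models.m2m_100 import M2M100Model  # triggers the torch-backed import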
|
|
@ -0,0 +1,165 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" M2M100 model configuration """
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/config.json",
|
||||
# See all M2M100 models at https://huggingface.co/models?filter=m2m_100
|
||||
}
|
||||
|
||||
|
||||
class M2M100Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.M2M100Model`. It is used to
|
||||
instantiate an M2M100 model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the M2M100 `m2m100_418M
|
||||
<https://huggingface.co/facebook/m2m100_418M>`__ architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
|
||||
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 128112):
|
||||
Vocabulary size of the M2M100 model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`input_ids` passed when calling :class:`~transformers.M2M100Model`.
|
||||
d_model (:obj:`int`, `optional`, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
encoder_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of encoder layers.
|
||||
decoder_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of decoder layers.
|
||||
encoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"relu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
|
||||
dropout (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
encoder_layerdrop (:obj:`float`, `optional`, defaults to 0.05):
|
||||
The LayerDrop probability for the encoder. See the `LayerDrop paper
|
||||
<https://arxiv.org/abs/1909.11556>`__ for more details.
|
||||
decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.05):
|
||||
The LayerDrop probability for the decoder. See the `LayerDrop paper
|
||||
<https://arxiv.org/abs/1909.11556>`__ for more details.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import M2M100Model, M2M100Config
|
||||
|
||||
>>> # Initializing a M2M100 facebook/m2m100_418M style configuration
|
||||
>>> configuration = M2M100Config()
|
||||
|
||||
>>> # Initializing a model from the facebook/m2m100_418M style configuration
|
||||
>>> model = M2M100Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "m2m_100"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=128112,
|
||||
max_position_embeddings=1024,
|
||||
encoder_layers=12,
|
||||
encoder_ffn_dim=4096,
|
||||
encoder_attention_heads=16,
|
||||
decoder_layers=12,
|
||||
decoder_ffn_dim=4096,
|
||||
decoder_attention_heads=16,
|
||||
encoder_layerdrop=0.05,
|
||||
decoder_layerdrop=0.05,
|
||||
use_cache=True,
|
||||
is_encoder_decoder=True,
|
||||
activation_function="relu",
|
||||
d_model=1024,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
activation_dropout=0.0,
|
||||
init_std=0.02,
|
||||
decoder_start_token_id=2,
|
||||
scale_embedding=True,
|
||||
gradient_checkpointing=False,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.d_model = d_model
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.activation_function = activation_function
|
||||
self.init_std = init_std
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.decoder_layerdrop = decoder_layerdrop
|
||||
self.use_cache = use_cache
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@property
|
||||
def hidden_size(self) -> int:
|
||||
return self.d_model
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import M2M100Config, M2M100ForConditionalGeneration
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = [
|
||||
"encoder.version",
|
||||
"decoder.version",
|
||||
"model.encoder.version",
|
||||
"model.decoder.version",
|
||||
"decoder.output_projection.weight",
|
||||
"_float_tensor",
|
||||
"encoder.embed_positions._float_tensor",
|
||||
"decoder.embed_positions._float_tensor",
|
||||
]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
def make_linear_from_emb(emb):
|
||||
vocab_size, emb_size = emb.weight.shape
|
||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||
lin_layer.weight.data = emb.weight.data
|
||||
return lin_layer
|
||||
|
||||
|
||||
def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path):
|
||||
m2m_100 = torch.load(checkpoint_path, map_location="cpu")
|
||||
args = m2m_100["args"]
|
||||
state_dict = m2m_100["model"]
|
||||
remove_ignore_keys_(state_dict)
|
||||
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
||||
|
||||
config = M2M100Config(
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=1024,
|
||||
encoder_layers=args.encoder_layers,
|
||||
decoder_layers=args.decoder_layers,
|
||||
encoder_attention_heads=args.encoder_attention_heads,
|
||||
decoder_attention_heads=args.decoder_attention_heads,
|
||||
encoder_ffn_dim=args.encoder_ffn_embed_dim,
|
||||
decoder_ffn_dim=args.decoder_ffn_embed_dim,
|
||||
d_model=args.encoder_embed_dim,
|
||||
encoder_layerdrop=args.encoder_layerdrop,
|
||||
decoder_layerdrop=args.decoder_layerdrop,
|
||||
dropout=args.dropout,
|
||||
attention_dropout=args.attention_dropout,
|
||||
activation_dropout=args.activation_dropout,
|
||||
activation_function="relu",
|
||||
)
|
||||
|
||||
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
|
||||
model = M2M100ForConditionalGeneration(config)
|
||||
model.model.load_state_dict(state_dict)
|
||||
model.lm_head = make_linear_from_emb(model.model.shared)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("fairseq_path", type=str, help="path to a model.pt on local filesystem.")
|
||||
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
args = parser.parse_args()
|
||||
model = convert_fairseq_m2m100_checkpoint_from_disk(args.fairseq_path)
|
||||
model.save_pretrained(args.pytorch_dump_folder_path)
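# Usage sketch (the paths below are placeholders, not part of the original script): run the script with
# its two positional arguments, or call the conversion function directly from Python.
#
#   python <this_conversion_script>.py /path/to/m2m100/model.pt ./m2m100_hf
#
#   model = convert_fairseq_m2m100_checkpoint_from_disk("/path/to/m2m100/model.pt")
#   model.save_pretrained("./m2m100_hf")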
|
The diff for this file is not shown because it is too large.
|
@ -0,0 +1,334 @@
|
|||
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for M2M100."""
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece
|
||||
|
||||
from ...tokenization_utils import BatchEncoding, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SPIECE_UNDERLINE = "▁"
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"spm_file": "sentencepiece.bpe.model",
|
||||
"tokenizer_config_file": "tokenizer_config.json",
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/vocab.json",
|
||||
},
|
||||
"spm_file": {
|
||||
"facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentencepiece.bpe.model",
|
||||
},
|
||||
"tokenizer_config_file": {
|
||||
"facebook/m2m100_418M": "https://huggingface.co/facebook/m2m100_418M/resolve/main/tokenizer_config.json",
|
||||
},
|
||||
}
|
||||
|
||||
ALL_M2M100_MODELS = ["facebook/m2m100_418M", "facebook/m2m100_1.2B"]
|
||||
SPM_URL = "https://huggingface.co/facebook/m2m100_418M/resolve/main/sentence.bpe.model"
|
||||
|
||||
# fmt: off
|
||||
FAIRSEQ_LANGUAGE_CODES = ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"]
|
||||
# fmt: on
|
||||
|
||||
|
||||
class M2M100Tokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct an M2M100 tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.
|
||||
|
||||
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
|
||||
Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (:obj:`str`):
|
||||
Path to the vocabulary file.
|
||||
spm_file (:obj:`str`):
|
||||
Path to `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a .spm extension)
|
||||
that contains the vocabulary.
|
||||
src_lang (:obj:`str`, `optional`):
|
||||
A string representing the source language.
|
||||
tgt_lang (:obj:`str`, `optional`):
|
||||
A string representing the target language.
|
||||
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
|
||||
The end of sequence token.
|
||||
sep_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
|
||||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
|
||||
sequence classification or for a text and a question for question answering. It is also used as the last
|
||||
token of a sequence built with special tokens.
|
||||
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import M2M100Tokenizer
|
||||
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="ro")
|
||||
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||
>>> # model(**model_inputs, labels=labels) should work
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
max_model_input_sizes = {m: 1024 for m in ALL_M2M100_MODELS}
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
prefix_tokens: List[int] = []
|
||||
suffix_tokens: List[int] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
spm_file,
|
||||
src_lang=None,
|
||||
tgt_lang=None,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
sep_token="</s>",
|
||||
pad_token="<pad>",
|
||||
unk_token="<unk>",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
src_lang=src_lang,
|
||||
tgt_lang=tgt_lang,
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
sep_token=sep_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.encoder = load_json(vocab_file)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.spm_file = spm_file
|
||||
self.sp_model = load_spm(spm_file)
|
||||
|
||||
self.encoder_size = len(self.encoder)
|
||||
|
||||
self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in FAIRSEQ_LANGUAGE_CODES}
|
||||
|
||||
self.lang_token_to_id = {
|
||||
self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)
|
||||
}
|
||||
self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(FAIRSEQ_LANGUAGE_CODES)}
|
||||
self.id_to_lang_token = {v: k for k, v in self.lang_token_to_id.items()}
|
||||
self._additional_special_tokens = list(self.lang_token_to_id.keys())
|
||||
|
||||
self._src_lang = src_lang if src_lang is not None else "en"
|
||||
self.tgt_lang = tgt_lang
|
||||
self.cur_lang_id = self.get_lang_id(self._src_lang)
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
self.num_madeup_words = 8
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return len(self.encoder) + len(self.lang_token_to_id) + self.num_madeup_words
|
||||
|
||||
@property
|
||||
def src_lang(self) -> str:
|
||||
return self._src_lang
|
||||
|
||||
@src_lang.setter
|
||||
def src_lang(self, new_src_lang: str) -> None:
|
||||
self._src_lang = new_src_lang
|
||||
self.set_src_lang_special_tokens(self._src_lang)
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
return self.sp_model.EncodeAsPieces(text)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
if token in self.lang_token_to_id:
|
||||
return self.lang_token_to_id[token]
|
||||
return self.encoder.get(token, self.encoder[self.unk_token])
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
"""Converts an index (integer) in a token (str) using the decoder."""
|
||||
if index in self.id_to_lang_token:
|
||||
return self.id_to_lang_token[index]
|
||||
return self.decoder.get(index, self.unk_token)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
|
||||
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
|
||||
return out_string
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer ``prepare_for_model`` method.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
|
||||
if already_has_special_tokens:
|
||||
if token_ids_1 is not None:
|
||||
raise ValueError(
|
||||
"You should not supply a second sequence if the provided sequence of "
|
||||
"ids is already formatted with special tokens for the model."
|
||||
)
|
||||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
||||
prefix_ones = [1] * len(self.prefix_tokens)
|
||||
suffix_ones = [1] * len(self.suffix_tokens)
|
||||
if token_ids_1 is None:
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
|
||||
adding special tokens. An M2M100 sequence has the following format, where ``X`` represents the sequence:
|
||||
|
||||
- ``input_ids`` (for encoder) ``[src_lang_code] X [eos]``
|
||||
- ``decoder_input_ids`` (for decoder) ``[tgt_lang_code] X [eos]``
|
||||
|
||||
BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
|
||||
separator.
|
||||
|
||||
Args:
|
||||
token_ids_0 (:obj:`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (:obj:`List[int]`, `optional`):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
||||
"""
|
||||
if token_ids_1 is None:
|
||||
return self.prefix_tokens + token_ids_0 + self.suffix_tokens
|
||||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
|
||||
|
||||
def get_vocab(self) -> Dict:
|
||||
vocab = self.encoder.copy()
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, d: Dict) -> None:
|
||||
self.__dict__ = d
|
||||
self.sp_model = load_spm(self.spm_file)
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
save_dir = Path(save_directory)
|
||||
assert save_dir.is_dir(), f"{save_directory} should be a directory"
|
||||
vocab_save_path = save_dir / (
|
||||
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
spm_save_path = save_dir / (
|
||||
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["spm_file"]
|
||||
)
|
||||
|
||||
save_json(self.encoder, vocab_save_path)
|
||||
|
||||
if not spm_save_path.exists():
|
||||
copyfile(self.spm_file, spm_save_path)
|
||||
|
||||
return (str(vocab_save_path), str(spm_save_path))
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
src_lang: str = "en",
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
tgt_lang: str = "ro",
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
self.src_lang = src_lang
|
||||
self.tgt_lang = tgt_lang
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang: str) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
lang_token = self.get_lang_token(src_lang)
|
||||
self.cur_lang_id = self.lang_token_to_id[lang_token]
|
||||
self.prefix_tokens = [self.cur_lang_id]
|
||||
self.suffix_tokens = [self.eos_token_id]
|
||||
|
||||
def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
|
||||
"""Reset the special tokens to the target language setting. No prefix and suffix=[eos, tgt_lang_code]."""
|
||||
lang_token = self.get_lang_token(tgt_lang)
|
||||
self.cur_lang_id = self.lang_token_to_id[lang_token]
|
||||
self.prefix_tokens = [self.cur_lang_id]
|
||||
self.suffix_tokens = [self.eos_token_id]
|
||||
|
||||
def get_lang_token(self, lang: str) -> str:
|
||||
return self.lang_code_to_token[lang]
|
||||
|
||||
def get_lang_id(self, lang: str) -> int:
|
||||
lang_token = self.get_lang_token(lang)
|
||||
return self.lang_token_to_id[lang_token]
|
||||
|
||||
|
||||
def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
|
||||
spm = sentencepiece.SentencePieceProcessor()
|
||||
spm.Load(str(path))
|
||||
return spm
|
||||
|
||||
|
||||
def load_json(path: str) -> Union[Dict, List]:
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_json(data, path: str) -> None:
|
||||
with open(path, "w") as f:
|
||||
json.dump(data, f, indent=2)
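# Quick sketch of the language-code handling above (values shown are schematic): language codes are
# wrapped as "__xx__" tokens appended after the regular SentencePiece vocabulary, and the target-tokenizer
# context manager temporarily switches the prefix token to the target language.
#
#   tok = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="fr")
#   tok.get_lang_token("fr")   # "__fr__"
#   tok.get_lang_id("fr")      # tok.encoder_size + FAIRSEQ_LANGUAGE_CODES.index("fr")
#   with tok.as_target_tokenizer():
#       labels = tok("La vie est comme une boîte de chocolat.").input_ids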
|
|
@ -1593,6 +1593,27 @@ class LxmertXLayer:
|
|||
requires_pytorch(self)
|
||||
|
||||
|
||||
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class M2M100ForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class M2M100Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class MarianForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
|
|
@@ -0,0 +1,353 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch M2M100 model. """


import copy
import tempfile
import unittest

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor


if is_torch_available():
    import torch

    from transformers import M2M100Config, M2M100ForConditionalGeneration, M2M100Model, M2M100Tokenizer
    from transformers.models.m2m_100.modeling_m2m_100 import M2M100Decoder, M2M100Encoder


def prepare_m2m_100_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
):
    if attention_mask is None:
        attention_mask = input_ids.ne(config.pad_token_id)
    if decoder_attention_mask is None:
        decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
    return {
        "input_ids": input_ids,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_attention_mask,
    }

@require_torch
class M2M100ModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="relu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
            3,
        )
        input_ids[:, -1] = self.eos_token_id  # Eos Token

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = M2M100Config(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
        )
        inputs_dict = prepare_m2m_100_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = M2M100Model(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        model = M2M100Model(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = M2M100Encoder.from_pretrained(tmpdirname).to(torch_device)

        encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
            0
        ]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = M2M100Decoder.from_pretrained(tmpdirname).to(torch_device)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=inputs_dict["attention_mask"],
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)

@require_torch
class M2M100ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            M2M100Model,
            M2M100ForConditionalGeneration,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (M2M100ForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_head_masking = False
    test_missing_keys = False

    def setUp(self):
        self.model_tester = M2M100ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=M2M100Config)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_encoder_decoder_model_standalone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in (M2M100Model, M2M100ForConditionalGeneration):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = M2M100ForConditionalGeneration(config).eval().to(torch_device)
        if torch_device == "cuda":
            model.half()
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)


def _long_tensor(tok_lst):
    return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)


TOLERANCE = 1e-4


@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class M2M100ModelIntegrationTests(unittest.TestCase):
    @cached_property
    def default_tokenizer(self):
        return M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

    def test_inference_no_head(self):
        model = M2M100Model.from_pretrained("facebook/m2m100_418M").to(torch_device)
        input_ids = _long_tensor([[128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38, 2]])
        decoder_input_ids = _long_tensor([[2, 128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38]])
        inputs_dict = prepare_m2m_100_inputs_dict(model.config, input_ids, decoder_input_ids)
        with torch.no_grad():
            output = model(**inputs_dict)[0]
        expected_shape = torch.Size((1, 11, 1024))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))

    def test_inference_head(self):
        model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(torch_device)

        input_ids = _long_tensor([[128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38, 2]])
        decoder_input_ids = _long_tensor([[2, 128028, 98, 12, 30527, 2732, 159, 7755, 61904, 39144, 38]])
        inputs_dict = prepare_m2m_100_inputs_dict(model.config, input_ids, decoder_input_ids)
        with torch.no_grad():
            output = model(**inputs_dict)[0]
        expected_shape = torch.Size((1, 11, model.config.vocab_size))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))

    def test_seq_to_seq_generation(self):
        model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(torch_device)
        tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")

        src_fr = [
            "L'affaire NSA souligne l'absence totale de débat sur le renseignement",
            "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
            "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
        ]

        # The batch below checks that we don't add any hypotheses outside of the top num_beams
        dct = tokenizer(src_fr, padding=True, return_tensors="pt")

        hypotheses_batch = model.generate(
            input_ids=dct["input_ids"].to(torch_device),
            attention_mask=dct["attention_mask"].to(torch_device),
            num_beams=5,
            forced_bos_token_id=tokenizer.get_lang_id("en"),
        )

        expected_en = [
            "The NSA case highlights the total absence of intelligence debate",
            "I think there are two levels of response from the French government.",
            "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S. Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all communications in France.",
        ]

        generated = tokenizer.batch_decode(
            hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
        )
        assert generated == expected_en
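For readers of the integration tests above, a minimal generation sketch using the same checkpoint and the forced_bos_token_id mechanism (the input sentence is an illustrative assumption):

# Translate French to English, forcing the first generated token to be the target language code.
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")

inputs = tokenizer("La vie est belle.", return_tensors="pt")  # hypothetical input sentence
generated = model.generate(
    **inputs,
    num_beams=5,
    forced_bos_token_id=tokenizer.get_lang_id("en"),
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))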
@@ -0,0 +1,193 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from pathlib import Path
from shutil import copyfile

from transformers import M2M100Tokenizer, is_torch_available
from transformers.file_utils import is_sentencepiece_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch


if is_sentencepiece_available():
    from transformers.models.m2m_100.tokenization_m2m_100 import save_json, VOCAB_FILES_NAMES

from .test_tokenization_common import TokenizerTesterMixin


if is_sentencepiece_available():
    SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")


if is_torch_available():
    from transformers.models.m2m_100.modeling_m2m_100 import shift_tokens_right

EN_CODE = 128022
FR_CODE = 128028


@require_sentencepiece
class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = M2M100Tokenizer
    test_rust_tokenizer = False
    test_seq2seq = False

    def setUp(self):
        super().setUp()

        vocab = ["</s>", "<unk>", "▁This", "▁is", "▁a", "▁t", "est", "\u0120", "<pad>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        save_dir = Path(self.tmpdirname)
        save_json(vocab_tokens, save_dir / VOCAB_FILES_NAMES["vocab_file"])
        if not (save_dir / VOCAB_FILES_NAMES["spm_file"]).exists():
            copyfile(SAMPLE_SP, save_dir / VOCAB_FILES_NAMES["spm_file"])

        tokenizer = M2M100Tokenizer.from_pretrained(self.tmpdirname)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return M2M100Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self, tokenizer):
        return (
            "This is a test",
            "This is a test",
        )

    @unittest.skip("Skip this test while all models are still to be uploaded.")
    def test_pretrained_model_lists(self):
        pass

    def test_full_tokenizer(self):
        tokenizer = self.get_tokenizer()

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [2, 3, 4, 5, 6],
        )

        back_tokens = tokenizer.convert_ids_to_tokens([2, 3, 4, 5, 6])
        self.assertListEqual(back_tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        text = tokenizer.convert_tokens_to_string(tokens)
        self.assertEqual(text, "This is a test")


@require_torch
@require_sentencepiece
@require_tokenizers
class M2M100TokenizerIntegrationTest(unittest.TestCase):
    checkpoint_name = "facebook/m2m100_418M"
    src_text = [
        "In my opinion, there are two levels of response from the French government.",
        "NSA Affair Emphasizes Complete Lack of Debate on Intelligence",
    ]
    tgt_text = [
        "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
        "L'affaire NSA souligne l'absence totale de débat sur le renseignement",
    ]

    # fmt: off
    expected_src_tokens = [EN_CODE, 593, 1949, 115781, 4, 71586, 4234, 60633, 126233, 432, 123808, 15592, 1197, 117132, 120618, 5, 2]
    # fmt: on

    @classmethod
    def setUpClass(cls):
        cls.tokenizer: M2M100Tokenizer = M2M100Tokenizer.from_pretrained(
            cls.checkpoint_name, src_lang="en", tgt_lang="fr"
        )
        cls.pad_token_id = 1
        return cls

    def check_language_codes(self):
        self.assertEqual(self.tokenizer.get_lang_id("ar"), 128006)
        self.assertEqual(self.tokenizer.get_lang_id("en"), 128022)
        self.assertEqual(self.tokenizer.get_lang_id("ro"), 128076)
        self.assertEqual(self.tokenizer.get_lang_id("mr"), 128063)

    def test_tokenizer_batch_encode_plus(self):
        self.tokenizer.src_lang = "en"
        ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]
        self.assertListEqual(self.expected_src_tokens, ids)

    def test_tokenizer_decode_ignores_language_codes(self):
        self.assertIn(FR_CODE, self.tokenizer.all_special_ids)
        # fmt: off
        generated_ids = [FR_CODE, 5364, 82, 8642, 4, 294, 47, 8, 14028, 136, 3286, 9706, 6, 90797, 6, 144012, 162, 88128, 30061, 5, 2]
        # fmt: on
        result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        expected_french = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
        self.assertEqual(result, expected_french)
        self.assertNotIn(self.tokenizer.eos_token, result)

    def test_special_tokens_unaffected_by_save_load(self):
        tmpdirname = tempfile.mkdtemp()
        original_special_tokens = self.tokenizer.lang_token_to_id
        self.tokenizer.save_pretrained(tmpdirname)
        new_tok = M2M100Tokenizer.from_pretrained(tmpdirname)
        self.assertDictEqual(new_tok.lang_token_to_id, original_special_tokens)

    @require_torch
    def test_batch_fairseq_parity(self):
        self.tokenizer.src_lang = "en"
        self.tokenizer.tgt_lang = "fr"

        batch = self.tokenizer(self.src_text, padding=True, return_tensors="pt")
        with self.tokenizer.as_target_tokenizer():
            batch["labels"] = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt").input_ids

        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
        )

        for k in batch:
            batch[k] = batch[k].tolist()
        # fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
        assert batch.input_ids[1][0] == EN_CODE
        assert batch.input_ids[1][-1] == 2
        assert batch.labels[1][0] == FR_CODE
        assert batch.labels[1][-1] == 2
        assert batch.decoder_input_ids[1][:2] == [2, FR_CODE]

    @require_torch
    def test_src_lang_setter(self):
        self.tokenizer.src_lang = "mr"
        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
        self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

        self.tokenizer.src_lang = "zh"
        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
        self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])

    @require_torch
    def test_as_target_tokenizer(self):
        self.tokenizer.tgt_lang = "mr"
        with self.tokenizer.as_target_tokenizer():
            self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
            self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])

        self.tokenizer.tgt_lang = "zh"
        with self.tokenizer.as_target_tokenizer():
            self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
            self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
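The fairseq-parity test above depends on M2M100's sequence layout; a small sketch of what it asserts (the middle token ids are illustrative, only the eos and language-code ids come from the tests):

# encoder input:     [src_lang_code, ...tokens..., eos]
# labels:            [tgt_lang_code, ...tokens..., eos]
# decoder_input_ids: [eos, tgt_lang_code, ...tokens...]  (labels shifted right by one)
EOS, FR_CODE = 2, 128028
labels = [FR_CODE, 5364, 82, 8642, EOS]  # hypothetical label sequence
decoder_input_ids = [EOS] + labels[:-1]  # what shift_tokens_right produces here
assert decoder_input_ids[:2] == [EOS, FR_CODE]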
@@ -30,6 +30,8 @@ PATH_TO_DOC = "docs/source"
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [
    # models to ignore for not tested
    "M2M100Encoder",  # Building part of bigger (tested) model.
    "M2M100Decoder",  # Building part of bigger (tested) model.
    "LEDEncoder",  # Building part of bigger (tested) model.
    "LEDDecoder",  # Building part of bigger (tested) model.
    "BartDecoderWrapper",  # Building part of bigger (tested) model.

@@ -75,6 +77,8 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
    # models to ignore for model xxx mapping
    "M2M100Encoder",
    "M2M100Decoder",
    "LEDEncoder",
    "LEDDecoder",
    "BartDecoder",