Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. (#6594)
* add conversion script
* improve conversion script
* make style
* add tryout files
* fix
* update
* add causal bert
* better names
* add tokenizer file as well
* finish causal_bert
* fix small bugs
* improve generate
* change naming
* renaming
* renaming
* renaming
* remove leftover files
* clean files
* add fix tokenizer
* finalize
* correct slow test
* update docs
* small fixes
* fix link
* adapt check repo
* apply sams and sylvains recommendations
* fix import
* implement Lysandres recommendations
* fix logger warn
This commit is contained in:
Parent
d1691d90e5
Commit
7fd1febf38
@ -134,7 +134,10 @@ conversion utilities for the following models:
26. `Funnel Transformer <https://github.com/laiguokun/Funnel-Transformer>`_ (from CMU/Google Brain) released with the paper
    `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing
    <https://arxiv.org/abs/2006.03236>`_ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
27. `Other community models <https://huggingface.co/models>`_, contributed by the `community
27. `Bert For Sequence Generation <https://tfhub.dev/s?module-type=text-generation&subtype=module,placeholder>`_ (from Google) released with the paper
    `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
    <https://arxiv.org/abs/1907.12461>`_ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
28. `Other community models <https://huggingface.co/models>`_, contributed by the `community
    <https://huggingface.co/users>`_.

.. toctree::

@ -221,6 +224,7 @@ conversion utilities for the following models:
   model_doc/mbart
   model_doc/funnel
   model_doc/lxmert
   model_doc/bertgeneration
   internal/modeling_utils
   internal/tokenization_utils
   internal/pipelines_utils

@ -0,0 +1,82 @@
BertForSeqGeneration
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

The BertForSeqGeneration model is a BERT model that can be leveraged for sequence-to-sequence tasks using :class:`~transformers.EncoderDecoderModel` as proposed in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.

The abstract from the paper is the following:

*Unsupervised pre-training of large neural models has recently revolutionized Natural Language Processing. By warm-starting from the publicly released checkpoints, NLP practitioners have pushed the state-of-the-art on multiple benchmarks while saving significant amounts of compute time. So far the focus has been mainly on the Natural Language Understanding tasks. In this paper, we demonstrate the efficacy of pre-trained checkpoints for Sequence Generation. We developed a Transformer-based sequence-to-sequence model that is compatible with publicly available pre-trained BERT, GPT-2 and RoBERTa checkpoints and conducted an extensive empirical study on the utility of initializing our model, both encoder and decoder, with these checkpoints. Our models result in new state-of-the-art results on Machine Translation, Text Summarization, Sentence Splitting, and Sentence Fusion.*

Usage:

- The model can be used in combination with the :class:`~transformers.EncoderDecoderModel` to leverage two pretrained BERT checkpoints for subsequent fine-tuning (a generation sketch follows the examples below).

::

    # leverage checkpoints for a Bert2Bert model...
    from transformers import BertGenerationEncoder, BertGenerationDecoder, BertTokenizer, EncoderDecoderModel

    # use BERT's cls token as BOS token and sep token as EOS token
    encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102)
    # add cross-attention layers and use BERT's cls token as BOS token and sep token as EOS token
    decoder = BertGenerationDecoder.from_pretrained("bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
    bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

    # create tokenizer...
    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

    input_ids = tokenizer('This is a long article to summarize', add_special_tokens=False, return_tensors="pt").input_ids
    labels = tokenizer('This is a short summary', return_tensors="pt").input_ids

    # train...
    loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels, return_dict=True).loss
    loss.backward()

- Pretrained :class:`~transformers.EncoderDecoderModel` checkpoints are also directly available on the model hub, *e.g.*:

::

    # instantiate sentence fusion model
    sentence_fuser = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_discofuse")
    tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_discofuse")

    input_ids = tokenizer('This is the first sentence. This is the second sentence.', add_special_tokens=False, return_tensors="pt").input_ids

    outputs = sentence_fuser.generate(input_ids)

    print(tokenizer.decode(outputs[0]))
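
- After fine-tuning, the ``bert2bert`` model built in the first example can generate with the same API as the pretrained hub checkpoints above. The following is a minimal, untested sketch that reuses the ``bert2bert``, ``tokenizer`` and ``input_ids`` objects from that example; ``generate()`` falls back to the decoder's BOS token (BERT's CLS token, id 101) as the decoder start token.

::

    # sketch: summarize the article with the model trained above
    summary_ids = bert2bert.generate(input_ids)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))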


Tips:

- :class:`~transformers.BertGenerationEncoder` and :class:`~transformers.BertGenerationDecoder` should be used in combination with :class:`~transformers.EncoderDecoderModel`.
- For summarization, sentence splitting, sentence fusion and translation, no special tokens are required for the input. Therefore, no EOS token should be added to the end of the input (see the example below).
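
  A minimal sketch of the difference (assuming the ``bert-large-uncased`` tokenizer from the usage example above): ``add_special_tokens=False`` keeps BERT's ``[CLS]``/``[SEP]`` tokens out of the encoder input, while the target side is tokenized with the defaults so that a ``[SEP]`` (id 102) is appended and acts as the EOS token.

  ::

      # encoder input: no special tokens, hence no EOS appended
      input_ids = tokenizer("This is a long article to summarize", add_special_tokens=False).input_ids
      # target/labels: default special tokens, ends with [SEP] (id 102), which acts as EOS
      labels = tokenizer("This is a short summary").input_ids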

The original code can be found `here <https://tfhub.dev/s?module-type=text-generation&subtype=module,placeholder>`__.

BertGenerationConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationConfig
    :members:


BertGenerationTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationTokenizer
    :members:

BertGenerationEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationEncoder
    :members:


BertGenerationDecoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationDecoder
    :members:

@ -22,6 +22,7 @@ from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertCo
|
|||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
|
||||
from .configuration_bart import BartConfig
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_bert_generation import BertGenerationConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
|
@ -142,6 +143,7 @@ from .tokenization_albert import AlbertTokenizer
|
|||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
from .tokenization_bart import BartTokenizer, BartTokenizerFast
|
||||
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
||||
from .tokenization_bert_generation import BertGenerationTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
|
@ -279,6 +281,11 @@ if is_torch_available():
|
|||
BertPreTrainedModel,
|
||||
load_tf_weights_in_bert,
|
||||
)
|
||||
from .modeling_bert_generation import (
|
||||
BertGenerationDecoder,
|
||||
BertGenerationEncoder,
|
||||
load_tf_weights_in_bert_generation,
|
||||
)
|
||||
from .modeling_camembert import (
|
||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
CamembertForCausalLM,
|
||||
|
|
|
@ -20,6 +20,7 @@ from collections import OrderedDict
|
|||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
|
||||
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||
from .configuration_bert_generation import BertGenerationConfig
|
||||
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||
|
@ -97,6 +98,10 @@ CONFIG_MAPPING = OrderedDict(
|
|||
"albert",
|
||||
AlbertConfig,
|
||||
),
|
||||
(
|
||||
"bert-for-seq-generation",
|
||||
BertGenerationConfig,
|
||||
),
|
||||
(
|
||||
"camembert",
|
||||
CamembertConfig,
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" BertForSeqGeneration model configuration """
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class BertGenerationConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.BertGenerationPreTrainedModel`.
|
||||
It is used to instantiate a BertGenerationConfig model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
|
||||
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
|
||||
for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 50358):
|
||||
Vocabulary size of the BertForSeqGeneration model. Defines the number of different tokens that
|
||||
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertForSeqGeneration`.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 1024):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler.
|
||||
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import BertGenerationConfig, BertGenerationEncoder
|
||||
|
||||
>>> # Initializing a BertForSeqGeneration config
|
||||
>>> configuration = BertGenerationConfig()
|
||||
|
||||
>>> # Initializing a model from the config
|
||||
>>> model = BertGenerationEncoder(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
model_type = "bert-for-seq-generation"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50358,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=24,
|
||||
num_attention_heads=16,
|
||||
intermediate_size=4096,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
bos_token_id=2,
|
||||
eos_token_id=1,
|
||||
gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
|
@ -0,0 +1,88 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Seq2Seq TF Hub checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertGenerationConfig,
|
||||
BertGenerationDecoder,
|
||||
BertGenerationEncoder,
|
||||
load_tf_weights_in_bert_generation,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
|
||||
# Initialise PyTorch model
|
||||
bert_config = BertConfig.from_pretrained(
|
||||
"bert-large-cased",
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=512,
|
||||
is_decoder=True,
|
||||
add_cross_attention=True,
|
||||
)
|
||||
bert_config_dict = bert_config.to_dict()
|
||||
del bert_config_dict["type_vocab_size"]
|
||||
config = BertGenerationConfig(**bert_config_dict)
|
||||
if is_encoder:
|
||||
model = BertGenerationEncoder(config)
|
||||
else:
|
||||
model = BertGenerationDecoder(config)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_bert_generation(
|
||||
model,
|
||||
tf_hub_path,
|
||||
model_class="bert",
|
||||
is_encoder_named_decoder=is_encoder_named_decoder,
|
||||
is_encoder=is_encoder,
|
||||
)
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model and config to {}".format(pytorch_dump_path))
|
||||
model.save_pretrained(pytorch_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--is_encoder_named_decoder",
|
||||
action="store_true",
|
||||
help="Whether the decoder should be renamed to encoder in the PyTorch model.",
|
||||
)
|
||||
parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
|
||||
parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.tf_hub_path,
|
||||
args.pytorch_dump_path,
|
||||
args.is_encoder_named_decoder,
|
||||
args.vocab_size,
|
||||
is_encoder=args.is_encoder,
|
||||
)
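
# Hedged usage sketch (the script filename and all paths below are placeholders, not taken from
# this commit): converting an encoder checkpoint from a local TF Hub module could look like
#
#   python convert_bert_generation_tf_hub_checkpoint.py \
#       --tf_hub_path /path/to/tf_hub_module \
#       --pytorch_dump_path /path/to/pytorch_output_dir \
#       --is_encoder \
#       --vocab_size 50358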
|
|
@ -530,7 +530,7 @@ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None, o
|
|||
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
|
||||
elif "LMHead" in model_class:
|
||||
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
|
||||
elif "Model" in model_class:
|
||||
elif "Model" in model_class or "Encoder" in model_class:
|
||||
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
|
||||
else:
|
||||
raise ValueError(f"Docstring can't be built for model {model_class}")
|
||||
|
|
|
@ -383,7 +383,11 @@ class GenerationMixin:
|
|||
# see if BOS token can be used for decoder_start_token_id
|
||||
if bos_token_id is not None:
|
||||
decoder_start_token_id = bos_token_id
|
||||
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
|
||||
elif (
|
||||
hasattr(self.config, "decoder")
|
||||
and hasattr(self.config.decoder, "bos_token_id")
|
||||
and self.config.decoder.bos_token_id is not None
|
||||
):
|
||||
decoder_start_token_id = self.config.decoder.bos_token_id
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
|
@ -23,6 +23,7 @@ from .configuration_auto import (
|
|||
AutoConfig,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
BertGenerationConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
|
@ -73,6 +74,7 @@ from .modeling_bert import (
|
|||
BertLMHeadModel,
|
||||
BertModel,
|
||||
)
|
||||
from .modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
|
||||
from .modeling_camembert import (
|
||||
CamembertForCausalLM,
|
||||
CamembertForMaskedLM,
|
||||
|
@ -213,6 +215,7 @@ MODEL_MAPPING = OrderedDict(
|
|||
(ReformerConfig, ReformerModel),
|
||||
(FunnelConfig, FunnelModel),
|
||||
(LxmertConfig, LxmertModel),
|
||||
(BertGenerationConfig, BertGenerationEncoder),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -284,6 +287,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
|
|||
), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
|
||||
(CTRLConfig, CTRLLMHeadModel),
|
||||
(ReformerConfig, ReformerModelWithLMHead),
|
||||
(BertGenerationConfig, BertGenerationDecoder),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,505 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model specific for generation. """
|
||||
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .configuration_bert_generation import BertGenerationConfig
|
||||
from .file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_callable,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .modeling_bert import BertEncoder
|
||||
from .modeling_outputs import BaseModelOutput, CausalLMOutput
|
||||
from .modeling_utils import PreTrainedModel
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "BertGenerationConfig"
|
||||
_TOKENIZER_FOR_DOC = "BertGenerationTokenizer"
|
||||
|
||||
|
||||
def load_tf_weights_in_bert_generation(
|
||||
model, tf_hub_path, model_class, is_encoder_named_decoder=False, is_encoder=False
|
||||
):
|
||||
try:
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
import tensorflow_hub as hub
|
||||
import tensorflow_text # noqa: F401
|
||||
|
||||
tf.disable_eager_execution()
|
||||
except ImportError:
|
||||
logger.error(
    "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions."
|
||||
)
|
||||
raise
|
||||
tf_model = hub.Module(tf_hub_path)
|
||||
init = tf.global_variables_initializer()
|
||||
with tf.Session() as sess:
|
||||
init.run()
|
||||
all_variables = tf_model.variable_map
|
||||
keep_track_variables = all_variables.copy()
|
||||
for key in list(all_variables.keys()):
|
||||
if "global" in key:
|
||||
logger.info(f"Skipping {key}...")
|
||||
continue
|
||||
if not is_encoder:
|
||||
model_pointer = getattr(model, model_class)
|
||||
else:
|
||||
model_pointer = model
|
||||
is_embedding = False
|
||||
logger.info(f"Trying to match {key}...")
|
||||
# remove start_string = "module/bert/"
|
||||
sub_layers = key.split("/")[2:]
|
||||
if is_encoder_named_decoder and sub_layers[0] == "encoder":
|
||||
logger.info(f"Skipping encoder layer {key} for decoder")
|
||||
continue
|
||||
if is_encoder and sub_layers[0] == "decoder":
|
||||
logger.info(f"Skipping decoder layer {key} for encoder")
|
||||
continue
|
||||
for i, sub_layer in enumerate(sub_layers):
|
||||
if sub_layer == "embeddings":
|
||||
is_embedding = True
|
||||
elif sub_layer == "LayerNorm":
|
||||
is_embedding = False
|
||||
if "layer" in sub_layer:
|
||||
model_pointer = model_pointer.layer[int(sub_layer.split("_")[-1])]
|
||||
elif sub_layer in ["kernel", "gamma"]:
|
||||
model_pointer = model_pointer.weight
|
||||
elif sub_layer == "beta":
|
||||
model_pointer = model_pointer.bias
|
||||
elif sub_layer == "encdec":
|
||||
model_pointer = model_pointer.crossattention.self
|
||||
elif sub_layer == "encdec_output":
|
||||
model_pointer = model_pointer.crossattention.output
|
||||
elif is_encoder_named_decoder and sub_layer == "decoder":
|
||||
model_pointer = model_pointer.encoder
|
||||
else:
|
||||
if sub_layer == "attention" and "encdec" in sub_layers[i + 1]:
|
||||
continue
|
||||
try:
|
||||
model_pointer = getattr(model_pointer, sub_layer)
|
||||
except AttributeError:
|
||||
logger.info(f"Skipping to initialize {key} at {sub_layer}...")
|
||||
raise AttributeError
|
||||
|
||||
array = np.asarray(sess.run(all_variables[key]))
|
||||
if not is_embedding:
|
||||
logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, key))
|
||||
array = np.transpose(array)
|
||||
else:
|
||||
model_pointer = model_pointer.weight
|
||||
|
||||
try:
|
||||
assert (
|
||||
model_pointer.shape == array.shape
|
||||
), f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched"
|
||||
except AssertionError as e:
|
||||
e.args += (model_pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info(f"Initialize PyTorch weight {key}")
|
||||
|
||||
model_pointer.data = torch.from_numpy(array.astype(np.float32))
|
||||
keep_track_variables.pop(key, None)
|
||||
|
||||
logger.info("Weights not copied to PyTorch model: {}".format(", ".join(keep_track_variables.keys())))
|
||||
return model
|
||||
|
||||
|
||||
class BertGenerationEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
|
||||
def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
else:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
"""An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = BertGenerationConfig
|
||||
base_model_prefix = "bert"
|
||||
authorized_missing_keys = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
BERT_GENERATION_START_DOCSTRING = r"""
|
||||
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
||||
usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.BertGenerationConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
BERT_GENERATION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`transformers.BertGenerationTokenizer`.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
|
||||
plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare BertForSeqGeneration model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_GENERATION_START_DOCSTRING,
|
||||
)
|
||||
class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
||||
"""
|
||||
|
||||
The model can behave as an encoder (with only self-attention) as well
|
||||
as a decoder, in which case a layer of cross-attention is added between
|
||||
the self-attention layers, following the architecture described in `Attention is all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani,
|
||||
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
|
||||
This model should be used when leveraging Bert or Roberta checkpoints for the `EncoderDecoderModel` class as described in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
|
||||
|
||||
To behave as a decoder the model needs to be initialized with the
:obj:`is_decoder` argument of the configuration set to :obj:`True`.
To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an
:obj:`encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = BertGenerationEmbeddings(config)
|
||||
self.encoder = BertEncoder(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@add_start_docstrings_to_callable(BERT_GENERATION_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
|
||||
output_type=BaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
if the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
||||
is used in the cross-attention if the model is configured as a decoder.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output,) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=sequence_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class BertGenerationOnlyLMHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits = self.decoder(hidden_states)
|
||||
return logits
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,
|
||||
BERT_GENERATION_START_DOCSTRING,
|
||||
)
|
||||
class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
if not config.is_decoder:
|
||||
logger.warning("If you want to use `BertGenerationDecoder` as a standalone, add `is_decoder=True`.")
|
||||
|
||||
self.bert = BertGenerationEncoder(config)
|
||||
self.lm_head = BertGenerationOnlyLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
@add_start_docstrings_to_callable(BERT_GENERATION_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
||||
if the model is configured as a decoder.
|
||||
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
||||
is used in the cross-attention if the model is configured as a decoder.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction).
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import BertGenerationTokenizer, BertGenerationDecoder, BertGenerationConfig
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = BertGenerationTokenizer.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder')
|
||||
>>> config = BertGenerationConfig.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
|
||||
>>> config.is_decoder = True
|
||||
>>> model = BertGenerationDecoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder', config=config, return_dict=True)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> prediction_logits = outputs.logits
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[1:]
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return CausalLMOutput(
|
||||
loss=lm_loss,
|
||||
logits=prediction_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
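
# Hedged sketch of free-running generation with the decoder on its own (standalone decoding is
# mainly a sanity check; the intended use is inside EncoderDecoderModel). It reuses the
# google/bert_for_seq_generation_L-24_bbc_encoder checkpoint from the docstring example above and
# relies only on the generic generate() API:
#
#   tokenizer = BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
#   config = BertGenerationConfig.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder", is_decoder=True)
#   model = BertGenerationDecoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder", config=config)
#   input_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids
#   generated = model.generate(input_ids, max_length=20, do_sample=False)
#   print(tokenizer.decode(generated[0]))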
|
|
@ -238,8 +238,21 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||
), "If `model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has to be defined"
|
||||
from .modeling_auto import AutoModel
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
from .configuration_auto import AutoConfig
|
||||
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
|
||||
logger.info(
    f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model from a decoder model. Cross-attention and causal mask are disabled."
|
||||
)
|
||||
encoder_config.is_decoder = False
|
||||
encoder_config.add_cross_attention = False
|
||||
|
||||
kwargs_encoder["config"] = encoder_config
|
||||
|
||||
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||
encoder.config.is_decoder = False
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
|
|
|
@ -22,10 +22,12 @@ from .configuration_auto import (
|
|||
AutoConfig,
|
||||
BartConfig,
|
||||
BertConfig,
|
||||
BertGenerationConfig,
|
||||
CamembertConfig,
|
||||
CTRLConfig,
|
||||
DistilBertConfig,
|
||||
ElectraConfig,
|
||||
EncoderDecoderConfig,
|
||||
FlaubertConfig,
|
||||
FunnelConfig,
|
||||
GPT2Config,
|
||||
|
@ -49,6 +51,7 @@ from .configuration_utils import PretrainedConfig
|
|||
from .tokenization_albert import AlbertTokenizer
|
||||
from .tokenization_bart import BartTokenizer, BartTokenizerFast
|
||||
from .tokenization_bert import BertTokenizer, BertTokenizerFast
|
||||
from .tokenization_bert_generation import BertGenerationTokenizer
|
||||
from .tokenization_bert_japanese import BertJapaneseTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_ctrl import CTRLTokenizer
|
||||
|
@ -105,6 +108,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
|||
(FlaubertConfig, (FlaubertTokenizer, None)),
|
||||
(XLMConfig, (XLMTokenizer, None)),
|
||||
(CTRLConfig, (CTRLTokenizer, None)),
|
||||
(BertGenerationConfig, (BertGenerationTokenizer, None)),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -233,6 +237,14 @@ class AutoTokenizer:
|
|||
raise ValueError("Tokenizer class {} does not exist or is not currently imported.")
|
||||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
# if model is an encoder decoder, the encoder tokenizer class is used by default
|
||||
if isinstance(config, EncoderDecoderConfig):
|
||||
if type(config.decoder) is not type(config.encoder): # noqa: E721
|
||||
logger.warning(
    f"The encoder model config class: {config.encoder.__class__} is different from the decoder model config class: {config.decoder.__class__}. It is not recommended to use the `AutoTokenizer.from_pretrained(..)` method in this case. Please use the encoder and decoder specific tokenizer classes."
|
||||
)
|
||||
config = config.encoder
|
||||
|
||||
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
if tokenizer_class_fast and use_fast:
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tokenization class for model BertForSeqGeneration."""
|
||||
|
||||
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import List
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
|
||||
|
||||
tokenizer_url = (
|
||||
"https://s3.amazonaws.com/models.huggingface.co/bert/google/bert_for_seq_generation_L-24_bbc_encoder/spiece.model"
|
||||
)
|
||||
|
||||
|
||||
class BertGenerationTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Constructs a BertGenerationTokenizer tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .
|
||||
|
||||
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
||||
should refer to the superclass for more information regarding methods.
|
||||
|
||||
Args:
|
||||
vocab_file (:obj:`string`):
|
||||
`SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
|
||||
contains the vocabulary necessary to instantiate a tokenizer.
|
||||
eos_token (:obj:`string`, `optional`, defaults to :obj:`"</s>"`):
|
||||
The end of sequence token.
|
||||
bos_token (:obj:`string`, `optional`, defaults to :obj:`"<s>"`):
|
||||
The beginning of sequence token.
|
||||
unk_token (:obj:`string`, `optional`, defaults to :obj:`"<unk>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
pad_token (:obj:`string`, `optional`, defaults to :obj:`"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
prefix_tokens: List[int] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
unk_token="<unk>",
|
||||
pad_token="<pad>",
|
||||
sep_token="<::::>",
|
||||
**kwargs
|
||||
):
|
||||
# Pass the special tokens to the superclass constructor
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
sep_token=sep_token,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
except ImportError:
|
||||
logger.warning(
    "You need to install SentencePiece to use BertGenerationTokenizer:"
|
||||
"https://github.com/google/sentencepiece"
|
||||
"pip install sentencepiece"
|
||||
)
|
||||
raise
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(vocab_file)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.sp_model.get_piece_size()
|
||||
|
||||
def get_vocab(self):
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"You need to install SentencePiece to use BertGenerationTokenizer: https://github.com/google/sentencepiece"
|
||||
"pip install sentencepiece"
|
||||
)
|
||||
raise
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(self.vocab_file)
|
||||
|
||||
def _tokenize(self, text, sample=False):
|
||||
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
||||
if not sample:
|
||||
pieces = self.sp_model.EncodeAsPieces(text)
|
||||
else:
|
||||
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
|
||||
return pieces
|
||||
|
||||
def _convert_token_to_id(self, token):
    """ Converts a token (str) to an id using the vocab. """
|
||||
return self.sp_model.piece_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
    """Converts an index (integer) to a token (str) using the vocab."""
|
||||
token = self.sp_model.IdToPiece(index)
|
||||
return token
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
    """ Converts a sequence of tokens (string) to a single string. """
|
||||
out_string = self.sp_model.decode_pieces(tokens)
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self, save_directory):
|
||||
"""Save the sentencepiece vocabulary (copy original file) and special tokens file
|
||||
to a directory.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||
return
|
||||
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
|
@ -22,7 +22,6 @@ from typing import List, Optional
|
|||
import sentencepiece as spm
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_xlnet import SPIECE_UNDERLINE
|
||||
from .utils import logging
|
||||
|
||||
|
||||
|
@ -47,6 +46,8 @@ SHARED_MODEL_IDENTIFIERS = [
|
|||
"Musixmatch/umberto-wikipedia-uncased-v1",
|
||||
]
|
||||
|
||||
SPIECE_UNDERLINE = "▁"
|
||||
|
||||
|
||||
class CamembertTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
|
@ -253,7 +254,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
|
|||
import sentencepiece as spm
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
|
||||
"You need to install SentencePiece to use CamembertTokenizer: https://github.com/google/sentencepiece"
|
||||
"pip install sentencepiece"
|
||||
)
|
||||
raise
|
||||
|
|
|
@ -27,8 +27,6 @@ from .utils import logging
|
|||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SPIECE_UNDERLINE = "▁"
|
||||
|
||||
####################################################
|
||||
# Mapping from the keyword arguments names of Tokenizer `__init__`
|
||||
# to file names for serializing Tokenizer instances
|
||||
|
|
|
@@ -0,0 +1,234 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


if is_torch_available():
    from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder


class BertGenerationEncoderTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=50,
        initializer_range=0.02,
        use_labels=True,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.use_labels = use_labels
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if self.use_labels:
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = BertGenerationConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            is_decoder=False,
            initializer_range=self.initializer_range,
            return_dict=True,
        )

        return config, input_ids, input_mask, token_labels

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            input_mask,
            token_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            input_mask,
            token_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
        **kwargs,
    ):
        model = BertGenerationEncoder(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
        encoder_hidden_states,
        encoder_attention_mask,
        **kwargs,
    ):
        config.add_cross_attention = True
        model = BertGenerationEncoder(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
        *args,
    ):
        model = BertGenerationDecoder(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            token_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class BertGenerationEncoderTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (BertGenerationEncoder, BertGenerationDecoder) if is_torch_available() else ()

    def setUp(self):
        self.model_tester = BertGenerationEncoderTester(self)
        self.config_tester = ConfigTester(self, config_class=BertGenerationConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        # This regression test was failing with PyTorch < 1.3
        (
            config,
            input_ids,
            input_mask,
            token_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            input_mask,
            token_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
        self.assertIsNotNone(model)
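# A minimal sketch (not part of this diff) of how the classes exercised above can be tied into a
# single seq2seq model; the decoder keyword arguments and the dummy input text are assumptions.
import torch

from transformers import BertGenerationDecoder, BertGenerationEncoder, EncoderDecoderModel
from transformers.tokenization_bert_generation import BertGenerationTokenizer

# Warm-start both sides from the publicly released encoder checkpoint used by the slow test above.
encoder = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
decoder = BertGenerationDecoder.from_pretrained(
    "google/bert_for_seq_generation_L-24_bbc_encoder", add_cross_attention=True, is_decoder=True
)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

tokenizer = BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
input_ids = tokenizer("This is a long article to summarize.", return_tensors="pt").input_ids

with torch.no_grad():
    # Smoke-test forward pass through encoder and decoder; fine-tuning would add labels.
    outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)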
@@ -21,6 +21,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_modeling_bert import BertModelTester
from .test_modeling_bert_generation import BertGenerationEncoderTester
from .test_modeling_common import ids_tensor
from .test_modeling_gpt2 import GPT2ModelTester
from .test_modeling_roberta import RobertaModelTester
@@ -31,6 +32,9 @@ if is_torch_available():
    import torch

    from transformers import (
        AutoTokenizer,
        BertGenerationDecoder,
        BertGenerationEncoder,
        BertLMHeadModel,
        BertModel,
        BertTokenizer,
@@ -489,6 +493,67 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
        self.assertEqual(summary, EXPECTED_SUMMARY)


class BertForSeqGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    def get_pretrained_model(self):
        return EncoderDecoderModel.from_encoder_decoder_pretrained(
            "google/bert_for_seq_generation_L-24_bbc_encoder", "google/bert_for_seq_generation_L-24_bbc_encoder"
        )

    def get_encoder_decoder_model(self, config, decoder_config):
        encoder_model = BertGenerationEncoder(config)
        decoder_model = BertGenerationDecoder(decoder_config)
        return encoder_model, decoder_model

    def prepare_config_and_inputs(self):
        model_tester = BertGenerationEncoderTester(self)
        encoder_config_and_inputs = model_tester.prepare_config_and_inputs()
        decoder_config_and_inputs = model_tester.prepare_config_and_inputs_for_decoder()
        (
            config,
            input_ids,
            input_mask,
            token_labels,
        ) = encoder_config_and_inputs
        (
            decoder_config,
            decoder_input_ids,
            decoder_input_mask,
            decoder_token_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = decoder_config_and_inputs

        # make sure that cross attention layers are added
        decoder_config.add_cross_attention = True
        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_config": decoder_config,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_input_mask,
            "decoder_token_labels": decoder_token_labels,
            "encoder_hidden_states": encoder_hidden_states,
            "labels": decoder_token_labels,
        }

    @slow
    def test_roberta2roberta_summarization(self):
        model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_bbc")
        model.to(torch_device)
        tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_bbc")

        ARTICLE = """The problem is affecting people using the older versions of the PlayStation 3, called the "Fat" model.The problem isn't affecting the newer PS3 Slim systems that have been on sale since September last year.Sony have also said they are aiming to have the problem fixed shortly but is advising some users to avoid using their console for the time being."We hope to resolve this problem within the next 24 hours," a statement reads. "In the meantime, if you have a model other than the new slim PS3, we advise that you do not use your PS3 system, as doing so may result in errors in some functionality, such as recording obtained trophies, and not being able to restore certain data."We believe we have identified that this problem is being caused by a bug in the clock functionality incorporated in the system."The PlayStation Network is used by millions of people around the world.It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores."""

        EXPECTED_SUMMARY = """Sony has said that a bug in its PlayStation 3 console is preventing them from using the machine as a computer."""

        input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device)
        output_ids = model.generate(input_ids)
        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        self.assertEqual(summary, EXPECTED_SUMMARY)


class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    def get_encoder_decoder_model(self, config, decoder_config):
        encoder_model = RobertaModel(config)
@@ -0,0 +1,210 @@
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import unittest

from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow
from transformers.tokenization_bert_generation import BertGenerationTokenizer

from .test_tokenization_common import TokenizerTesterMixin


SPIECE_UNDERLINE = "▁"

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")


class BertForSeqGenerationTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertGenerationTokenizer

    def setUp(self):
        super().setUp()

        tokenizer = BertGenerationTokenizer(SAMPLE_VOCAB, keep_accents=True)
        tokenizer.save_pretrained(self.tmpdirname)

    def test_full_tokenizer(self):
        tokenizer = BertGenerationTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [285, 46, 10, 170, 382],
        )

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids,
            [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
        )

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )

    @cached_property
    def big_tokenizer(self):
        return BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")

    @slow
    def test_tokenization_base_easy_symbols(self):
        symbols = "Hello World!"
        original_tokenizer_encodings = [18536, 2260, 101]

        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

    @slow
    def test_tokenization_base_hard_symbols(self):
        symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
        original_tokenizer_encodings = [
            871,
            419,
            358,
            946,
            991,
            2521,
            452,
            358,
            1357,
            387,
            7751,
            3536,
            112,
            985,
            456,
            126,
            865,
            938,
            5400,
            5734,
            458,
            1368,
            467,
            786,
            2462,
            5246,
            1159,
            633,
            865,
            4519,
            457,
            582,
            852,
            2557,
            427,
            916,
            508,
            405,
            34324,
            497,
            391,
            408,
            11342,
            1244,
            385,
            100,
            938,
            985,
            456,
            574,
            362,
            12597,
            3200,
            3129,
            1172,
        ]

        self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

    @slow
    @require_torch
    def test_torch_encode_plus_sent_to_model(self):
        import torch

        from transformers import BertGenerationConfig, BertGenerationEncoder

        # Build sequence
        first_ten_tokens = list(self.big_tokenizer.get_vocab().keys())[:10]
        sequence = " ".join(first_ten_tokens)
        encoded_sequence = self.big_tokenizer.encode_plus(sequence, return_tensors="pt", return_token_type_ids=False)
        batch_encoded_sequence = self.big_tokenizer.batch_encode_plus(
            [sequence + " " + sequence], return_tensors="pt", return_token_type_ids=False
        )

        config = BertGenerationConfig()
        model = BertGenerationEncoder(config)

        assert model.get_input_embeddings().weight.shape[0] >= self.big_tokenizer.vocab_size

        with torch.no_grad():
            model(**encoded_sequence)
            model(**batch_encoded_sequence)
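# A quick, illustrative round trip with the SentencePiece tokenizer exercised by the slow tests
# above; a sketch only, assuming the pretrained checkpoint can be downloaded.
from transformers.tokenization_bert_generation import BertGenerationTokenizer

tok = BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
ids = tok.encode("Hello World!")  # the slow test above expects [18536, 2260, 101]
text = tok.decode(ids, skip_special_tokens=True)  # should roughly recover "Hello World!"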
@@ -21,11 +21,12 @@ from transformers import BatchEncoding
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available
from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE

from .test_tokenization_common import TokenizerTesterMixin


SPIECE_UNDERLINE = "▁"

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

FRAMEWORK = "pt" if _torch_available else "tf"
@@ -47,6 +47,7 @@ MODEL_NAME_TO_DOC_FILE = {
    "openai": "gpt.rst",
    "transfo_xl": "transformerxl.rst",
    "xlm_roberta": "xlmroberta.rst",
    "bert_generation": "bertgeneration.rst",
}

# This is to make sure the transformers module imported is the one in the repo.
@@ -230,6 +231,9 @@ def _get_model_name(module):
    # Special case for xlm_roberta
    if splits[-1] == "roberta" and splits[-2] == "xlm":
        return "_".join(splits[-2:])
    # Special case for bert_generation
    if splits[-1] == "generation" and splits[-2] == "bert":
        return "_".join(splits[-2:])
    return splits[-1]
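# A standalone sketch (hypothetical helper, not in the repo) of the effect of the special cases
# above: module names such as "modeling_bert_generation" resolve to the key "bert_generation",
# which MODEL_NAME_TO_DOC_FILE then maps to "bertgeneration.rst".
def _model_key(module_name: str) -> str:
    splits = module_name.split("_")
    if splits[-2:] in (["xlm", "roberta"], ["bert", "generation"]):
        return "_".join(splits[-2:])
    return splits[-1]


assert _model_key("modeling_bert_generation") == "bert_generation"
assert _model_key("modeling_xlm_roberta") == "xlm_roberta"
assert _model_key("modeling_bert") == "bert"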