Change model outputs types to self-document outputs (#5438)

* [WIP] Proposal for model outputs

* All Bert models

* Make CI green maybe?

* Fix ONNX test

* Isolate ModelOutput from pt and tf

* Formatting

* Add Electra models

* Auto-generate docstrings from outputs

* Add TF outputs

* Add some BERT models

* Revert TF side

* Remove last traces of TF changes

* Fail with a clear error message

* Add Albert and work through Bart

* Add CTRL and DistilBert

* Formatting

* Progress on Bart

* Renames and finish Bart

* Formatting

* Fix last test

* Add DPR

* Finish Electra and add FlauBERT

* Add GPT2

* Add Longformer

* Add MMBT

* Add MobileBert

* Add GPT

* Formatting

* Add Reformer

* Add Roberta

* Add T5

* Add Transformer XL

* Fix test

* Add XLM + fix XLMForTokenClassification

* Style + XLMRoberta

* Add XLNet

* Formatting

* Add doc of return_tuple arg
This commit is contained in:
Sylvain Gugger 2020-07-10 11:36:53 -04:00 коммит произвёл GitHub
Родитель fa265230a2
Коммит edfd82f5ff
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
33 изменённых файлов: 3458 добавлений и 2292 удалений

Просмотреть файл

@ -386,6 +386,9 @@ def start_memory_tracing(
elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
return traceit
if "__name__" not in frame.f_globals:
return traceit
# Filter modules
name = frame.f_globals["__name__"]
if not isinstance(name, str):

Просмотреть файл

@ -49,6 +49,8 @@ class PretrainedConfig(object):
Whether or not the model should returns all attentions.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return the last key/values attentions (not used by all models).
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.
is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model is used as an encoder/decoder or not.
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
@ -131,6 +133,7 @@ class PretrainedConfig(object):
def __init__(self, **kwargs):
# Attributes with defaults
self.return_tuple = kwargs.pop("return_tuple", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
@ -190,6 +193,11 @@ class PretrainedConfig(object):
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
@property
def use_return_tuple(self):
# If torchscript is set, force return_tuple to avoid jit errors
return self.return_tuple or self.torchscript
@property
def num_labels(self) -> int:
return len(self.id2label)

Просмотреть файл

@ -4,6 +4,7 @@ from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple
from transformers import is_tf_available, is_torch_available
from transformers.file_utils import ModelOutput
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
@ -89,7 +90,8 @@ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], D
tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
seq_len = tokens.input_ids.shape[-1]
outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
if isinstance(outputs, ModelOutput):
outputs = outputs.to_tuple()
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)

Просмотреть файл

@ -8,6 +8,7 @@ import fnmatch
import json
import logging
import os
import re
import shutil
import sys
import tarfile
@ -186,6 +187,31 @@ def add_end_docstrings(*docstr):
return docstring_decorator
RETURN_INTRODUCTION = r"""
Returns:
:class:`~transformers.{output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
"""
def _prepare_output_docstrings(output_type, config_class):
"""
Prepares the return part of the docstring using `output_type`.
"""
docstrings = output_type.__doc__
# Remove the head of the docstring to keep the list of args only
lines = docstrings.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
i += 1
if i < len(lines):
docstrings = "\n".join(lines[(i + 1) :])
# Add the return introduction
intro = RETURN_INTRODUCTION.format(output_type=output_type.__name__, config_class=config_class)
return intro + docstrings
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
@ -414,7 +440,7 @@ TF_CAUSAL_LM_SAMPLE = r"""
"""
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None):
def docstring_decorator(fn):
model_class = fn.__qualname__.split(".")[0]
is_tf_class = model_class[:2] == "TF"
@ -436,8 +462,29 @@ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""
built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc
return fn
return docstring_decorator
def replace_return_docstrings(output_type=None, config_class=None):
def docstring_decorator(fn):
docstrings = fn.__doc__
lines = docstrings.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
i += 1
if i < len(lines):
lines[i] = _prepare_output_docstrings(output_type, config_class)
docstrings = "\n".join(lines)
else:
raise ValueError(
f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}"
)
fn.__doc__ = docstrings
return fn
return docstring_decorator
@ -806,3 +853,22 @@ def tf_required(func):
raise ImportError(f"Method `{func.__name__}` requires TF.")
return wrapper
class ModelOutput:
"""
Base class for all model outputs as dataclass. Has a ``__getitem__`` (to make it behave like a ``namedtuple``) that
will ignore ``None`` in the attributes.
"""
def to_tuple(self):
return tuple(getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None)
def to_dict(self):
return {f: getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None}
def __getitem__(self, i):
return self.to_dict()[i] if isinstance(i, str) else self.to_tuple()[i]
def __len__(self):
return len(self.to_tuple())

Просмотреть файл

@ -18,19 +18,37 @@ import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .configuration_albert import AlbertConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer"
@ -322,14 +340,18 @@ class AlbertTransformer(nn.Module):
self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
def forward(
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_tuple=False,
):
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_attentions = ()
if output_hidden_states:
all_hidden_states = (hidden_states,)
all_hidden_states = (hidden_states,) if output_hidden_states else None
all_attentions = () if output_attentions else None
for i in range(self.config.num_hidden_layers):
# Number of layers in a hidden group
@ -353,12 +375,11 @@ class AlbertTransformer(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
if return_tuple:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class AlbertPreTrainedModel(PreTrainedModel):
@ -383,6 +404,39 @@ class AlbertPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@dataclass
class AlbertForPretrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.AlbertForPretrainingModel`.
Args:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
prediction_logits: torch.FloatTensor
sop_logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
ALBERT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@ -432,6 +486,10 @@ ALBERT_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -487,7 +545,12 @@ class AlbertModel(AlbertPreTrainedModel):
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="albert-base-v2",
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -498,38 +561,13 @@ class AlbertModel(AlbertPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -561,16 +599,22 @@ class AlbertModel(AlbertPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
outputs = (sequence_output, pooled_output) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs
if return_tuple:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
@ -596,6 +640,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
return self.predictions.decoder
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AlbertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -608,6 +653,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
sentence_order_label=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs,
):
r"""
@ -625,26 +671,6 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
sop_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -668,6 +694,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
)
labels = kwargs.pop("masked_lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.albert(
input_ids,
@ -678,6 +705,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output, pooled_output = outputs[:2]
@ -685,16 +713,24 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
prediction_scores = self.predictions(sequence_output)
sop_scores = self.sop_classifier(pooled_output)
outputs = (prediction_scores, sop_scores,) + outputs[2:] # add hidden states and attention if they are here
total_loss = None
if labels is not None and sentence_order_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
total_loss = masked_lm_loss + sentence_order_loss
outputs = (total_loss,) + outputs
return outputs # (loss), prediction_scores, sop_scores, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores, sop_scores) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return AlbertForPretrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
sop_logits=sop_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class AlbertMLMHead(nn.Module):
@ -754,7 +790,12 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
return self.predictions.decoder
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="albert-base-v2",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -766,6 +807,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -776,24 +818,6 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
labels in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if "masked_lm_labels" in kwargs:
warnings.warn(
@ -802,6 +826,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
)
labels = kwargs.pop("masked_lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.albert(
input_ids=input_ids,
@ -812,18 +837,27 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_outputs = outputs[0]
prediction_scores = self.predictions(sequence_outputs)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs
if return_tuple:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -843,7 +877,12 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="albert-base-v2",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -855,6 +894,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -862,25 +902,8 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.albert(
input_ids=input_ids,
@ -891,6 +914,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
@ -898,8 +922,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -908,9 +931,14 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -930,7 +958,12 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="albert-base-v2",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -942,30 +975,14 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.albert(
input_ids,
@ -976,6 +993,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -983,8 +1001,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -995,9 +1012,14 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1016,7 +1038,12 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="albert-base-v2",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1029,6 +1056,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1039,27 +1067,8 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
end_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.albert(
input_ids=input_ids,
@ -1070,6 +1079,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1079,7 +1089,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1095,9 +1105,18 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1116,7 +1135,12 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="albert-base-v2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="albert-base-v2",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1128,33 +1152,15 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -1175,6 +1181,7 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
@ -1183,11 +1190,15 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
if return_tuple:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)

Просмотреть файл

@ -32,12 +32,22 @@ from .file_utils import (
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
)
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"
@ -103,6 +113,10 @@ BART_INPUTS_DOCSTRING = r"""
See diagram 1 in the paper for more info on the default strategy
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -280,20 +294,22 @@ class BartEncoder(nn.Module):
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
def forward(self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False):
def forward(
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_tuple=False
):
"""
Args:
input_ids (LongTensor): tokens in the source language of shape
`(batch, src_len)`
attention_mask (torch.LongTensor): indicating which indices are padding tokens.
Returns:
Tuple comprised of:
BaseModelOutput or Tuple comprised of:
- **x** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *output_hidden_states:* is True.
- **all_attentions** (List[Tensor]): Attention weights for each layer.
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
@ -309,7 +325,8 @@ class BartEncoder(nn.Module):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_states, all_attentions = [], []
encoder_states = [] if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states.append(x)
@ -321,18 +338,21 @@ class BartEncoder(nn.Module):
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
if output_attentions:
all_attentions.append(attn)
all_attentions = all_attentions + (attn,)
if self.layer_norm:
x = self.layer_norm(x)
if output_hidden_states:
encoder_states.append(x)
# T x B x C -> B x T x C
encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)
# T x B x C -> B x T x C
encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
x = x.transpose(0, 1)
return x, encoder_states, all_attentions
if return_tuple:
return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
class DecoderLayer(nn.Module):
@ -466,6 +486,7 @@ class BartDecoder(nn.Module):
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_tuple=False,
**unused,
):
"""
@ -481,8 +502,9 @@ class BartDecoder(nn.Module):
decoder_cached_states (dict or None): dictionary used for storing state during generation
Returns:
tuple:
BaseModelOutputWithPast or tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- the cache
- hidden states
- attentions
"""
@ -508,8 +530,8 @@ class BartDecoder(nn.Module):
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
# decoder layers
all_hidden_states = ()
all_self_attns = ()
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@ -540,7 +562,8 @@ class BartDecoder(nn.Module):
all_self_attns += (layer_self_attn,)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
all_hidden_states = [hidden_state.transpose(0, 1) for hidden_state in all_hidden_states]
if output_hidden_states:
all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
@ -548,7 +571,12 @@ class BartDecoder(nn.Module):
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
return x, next_cache, all_hidden_states, list(all_self_attns)
if return_tuple:
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
)
def _reorder_buffer(attn_cache, new_order):
@ -792,11 +820,6 @@ def fill_with_neg_inf(t):
return t.float().fill_(float("-inf")).type_as(t)
def _filter_out_falsey_values(tup) -> Tuple:
"""Remove entries that are None or [] from an iterable."""
return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)
# Public API
def _get_shape(t):
return getattr(t, "shape", None)
@ -818,7 +841,12 @@ class BartModel(PretrainedBartModel):
self.init_weights()
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="facebook/bart-large")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids,
@ -830,6 +858,7 @@ class BartModel(PretrainedBartModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
if decoder_input_ids is None:
@ -840,6 +869,7 @@ class BartModel(PretrainedBartModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# make masks if user doesn't supply
if not use_cache:
@ -861,8 +891,16 @@ class BartModel(PretrainedBartModel):
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
assert isinstance(encoder_outputs, tuple)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOuput when return_tuple=False
elif not return_tuple and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids,
@ -871,16 +909,24 @@ class BartModel(PretrainedBartModel):
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=decoder_cached_states,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
return_tuple=return_tuple,
)
# Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
assert isinstance(decoder_outputs[0], torch.Tensor)
encoder_outputs: Tuple = _filter_out_falsey_values(encoder_outputs)
return decoder_outputs + encoder_outputs
if return_tuple:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.shared
@ -922,6 +968,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
self.register_buffer("final_logits_bias", new_bias)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(BART_GENERATION_EXAMPLE)
def forward(
self,
@ -935,6 +982,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**unused,
):
r"""
@ -942,26 +990,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
Labels for computing the masked language modeling loss.
Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
with labels
in ``[0, ..., config.vocab_size]``.
with labels in ``[0, ..., config.vocab_size]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Conditional generation example::
@ -987,6 +1018,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
DeprecationWarning,
)
labels = unused.pop("lm_labels")
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if labels is not None:
use_cache = False
@ -1001,16 +1033,30 @@ class BartForConditionalGeneration(PretrainedBartModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
# TODO(SS): do we need to ignore pad tokens in labels?
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs
if return_tuple:
output = (lm_logits,) + outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
decoder_past_key_values=outputs.decoder_past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
@ -1083,7 +1129,12 @@ class BartForSequenceClassification(PretrainedBartModel):
self.model._init_weights(self.classification_head.out_proj)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="facebook/bart-large")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=Seq2SeqSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids,
@ -1092,32 +1143,18 @@ class BartForSequenceClassification(PretrainedBartModel):
decoder_input_ids=None,
decoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
use_cache=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification loss (cross entropy)
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the
self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if labels is not None:
use_cache = False
@ -1127,9 +1164,10 @@ class BartForSequenceClassification(PretrainedBartModel):
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
return_tuple=return_tuple,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
@ -1137,13 +1175,25 @@ class BartForSequenceClassification(PretrainedBartModel):
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
logits = self.classification_head(sentence_representation)
# Prepend logits
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
if labels is not None: # prepend loss to output,
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs
loss = None
if labels is not None:
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
if return_tuple:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqSequenceClassifierOutput(
loss=loss,
logits=logits,
decoder_past_key_values=outputs.decoder_past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
@ -1164,7 +1214,12 @@ class BartForQuestionAnswering(PretrainedBartModel):
self.model._init_weights(self.qa_outputs)
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="facebook/bart-large")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=Seq2SeqQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids,
@ -1174,9 +1229,10 @@ class BartForQuestionAnswering(PretrainedBartModel):
decoder_attention_mask=None,
start_positions=None,
end_positions=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
use_cache=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1187,24 +1243,8 @@ class BartForQuestionAnswering(PretrainedBartModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BartConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if start_positions is not None and end_positions is not None:
use_cache = False
@ -1214,9 +1254,10 @@ class BartForQuestionAnswering(PretrainedBartModel):
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1226,7 +1267,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[1:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1242,9 +1283,22 @@ class BartForQuestionAnswering(PretrainedBartModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # return outputs # (loss), start_logits, end_logits, encoder_outputs, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits,) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return Seq2SeqQuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
decoder_past_key_values=outputs.decoder_past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
class SinusoidalPositionalEmbedding(nn.Embedding):

Просмотреть файл

@ -20,6 +20,8 @@ import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
@ -28,12 +30,30 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu, gelu_new, swish
from .configuration_bert import BertConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
CausalLMOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -406,9 +426,10 @@ class BertEncoder(nn.Module):
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_tuple=False,
):
all_hidden_states = ()
all_attentions = ()
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -439,20 +460,17 @@ class BertEncoder(nn.Module):
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
if return_tuple:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class BertPooler(nn.Module):
@ -561,6 +579,39 @@ class BertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
@dataclass
class BertForPretrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.BertForPretrainingModel`.
Args:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
prediction_logits: torch.FloatTensor
seq_relationship_logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
BERT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
@ -618,7 +669,9 @@ BERT_INPUTS_DOCSTRING = r"""
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states tensors of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -668,7 +721,12 @@ class BertModel(BertPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -681,37 +739,13 @@ class BertModel(BertPreTrainedModel):
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -762,14 +796,20 @@ class BertModel(BertPreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
if return_tuple:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
@ -790,6 +830,7 @@ class BertForPreTraining(BertPreTrainedModel):
return self.cls.predictions.decoder
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=BertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -802,6 +843,7 @@ class BertForPreTraining(BertPreTrainedModel):
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -819,26 +861,6 @@ class BertForPreTraining(BertPreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -861,6 +883,7 @@ class BertForPreTraining(BertPreTrainedModel):
)
labels = kwargs.pop("masked_lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.bert(
input_ids,
@ -871,23 +894,30 @@ class BertForPreTraining(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
outputs = (prediction_scores, seq_relationship_score,) + outputs[
2:
] # add hidden states and attention if they are here
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
outputs = (total_loss,) + outputs
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return BertForPretrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -907,6 +937,7 @@ class BertLMHeadModel(BertPreTrainedModel):
return self.cls.predictions.decoder
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -920,6 +951,7 @@ class BertLMHeadModel(BertPreTrainedModel):
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -932,22 +964,6 @@ class BertLMHeadModel(BertPreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Example::
@ -962,8 +978,9 @@ class BertLMHeadModel(BertPreTrainedModel):
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
>>> prediction_scores = outputs.prediction_scores
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.bert(
input_ids,
@ -976,22 +993,27 @@ class BertLMHeadModel(BertPreTrainedModel):
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (ltr_lm_loss,) + outputs
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutput(
loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
@ -1020,7 +1042,12 @@ class BertForMaskedLM(BertPreTrainedModel):
return self.cls.predictions.decoder
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1034,6 +1061,7 @@ class BertForMaskedLM(BertPreTrainedModel):
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -1044,24 +1072,6 @@ class BertForMaskedLM(BertPreTrainedModel):
in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if "masked_lm_labels" in kwargs:
warnings.warn(
@ -1072,6 +1082,8 @@ class BertForMaskedLM(BertPreTrainedModel):
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
@ -1083,19 +1095,27 @@ class BertForMaskedLM(BertPreTrainedModel):
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
@ -1125,6 +1145,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -1136,6 +1157,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1145,24 +1167,8 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
``1`` indicates sequence B is a random sequence.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
Next sequence prediction (classification) loss.
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
Example::
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
>>> import torch
@ -1174,9 +1180,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
>>> logits = outputs.seq_relationship_scores
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.bert(
input_ids,
@ -1187,19 +1195,28 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
seq_relationship_score = self.cls(pooled_output)
seq_relationship_scores = self.cls(pooled_output)
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
next_sentence_loss = None
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
outputs = (next_sentence_loss,) + outputs
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), next_sentence_label.view(-1))
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
if return_tuple:
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1219,7 +1236,12 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1231,6 +1253,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1238,25 +1261,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.bert(
input_ids,
@ -1267,6 +1273,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
@ -1274,8 +1281,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -1284,9 +1290,14 @@ class BertForSequenceClassification(BertPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1305,7 +1316,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1317,33 +1333,15 @@ class BertForMultipleChoice(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -1365,6 +1363,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
@ -1373,14 +1372,18 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
if return_tuple:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1400,7 +1403,12 @@ class BertForTokenClassification(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1412,30 +1420,14 @@ class BertForTokenClassification(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.bert(
input_ids,
@ -1446,6 +1438,7 @@ class BertForTokenClassification(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1453,7 +1446,7 @@ class BertForTokenClassification(BertPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -1466,9 +1459,14 @@ class BertForTokenClassification(BertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1487,7 +1485,12 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1500,6 +1503,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1510,27 +1514,8 @@ class BertForQuestionAnswering(BertPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.bert(
input_ids,
@ -1541,6 +1526,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1550,7 +1536,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1566,6 +1552,15 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

Просмотреть файл

@ -53,6 +53,10 @@ CAMEMBERT_START_DOCSTRING = r"""
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""

Просмотреть файл

@ -25,11 +25,13 @@ from torch.nn import CrossEntropyLoss
from .configuration_ctrl import CTRLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "CTRLConfig"
_TOKENIZER_FOR_DOC = "CTRLTokenizer"
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -288,6 +290,10 @@ CTRL_INPUTS_DOCSTRING = r"""
can be used to speed up decoding (see `past`). Defaults to `True`.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -328,7 +334,12 @@ class CTRLModel(CTRLPreTrainedModel):
self.h[layer].multi_head_attention.prune_heads(heads)
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -341,32 +352,14 @@ class CTRLModel(CTRLPreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -435,9 +428,9 @@ class CTRLModel(CTRLPreTrainedModel):
hidden_states = self.dropout(hidden_states)
output_shape = input_shape + (inputs_embeds.size(-1),)
presents = ()
all_hidden_states = ()
all_attentions = []
presents = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = [] if output_attentions else None
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
@ -462,17 +455,20 @@ class CTRLModel(CTRLPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs
if return_tuple:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
@add_start_docstrings(
@ -499,7 +495,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="ctrl")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -513,6 +514,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -521,28 +523,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.CTRLConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
past=past,
@ -554,14 +537,14 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
@ -569,6 +552,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
if return_tuple:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

Просмотреть файл

@ -30,12 +30,26 @@ from torch.nn import CrossEntropyLoss
from .activations import gelu
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -264,7 +278,9 @@ class Transformer(nn.Module):
layer = TransformerBlock(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])
def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False):
def forward(
self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False, return_tuple=None
):
"""
Parameters
----------
@ -284,8 +300,8 @@ class Transformer(nn.Module):
Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True
"""
all_hidden_states = ()
all_attentions = ()
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_state = x
for i, layer_module in enumerate(self.layer):
@ -308,12 +324,11 @@ class Transformer(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
if return_tuple:
return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
)
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
@ -379,6 +394,10 @@ DISTILBERT_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -410,6 +429,12 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.transformer.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="distilbert-base-uncased",
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="distilbert-base-uncased")
def forward(
self,
@ -419,28 +444,13 @@ class DistilBertModel(DistilBertPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -461,17 +471,14 @@ class DistilBertModel(DistilBertPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(
return self.transformer(
x=inputs_embeds,
attn_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_state = tfmr_output[0]
output = (hidden_state,) + tfmr_output[1:]
return output # last-layer hidden-state, (all hidden_states), (all attentions)
@add_start_docstrings(
@ -494,7 +501,12 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
return self.vocab_projector
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="distilbert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="distilbert-base-uncased",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -504,6 +516,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -514,25 +527,6 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if "masked_lm_labels" in kwargs:
warnings.warn(
@ -541,6 +535,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
)
labels = kwargs.pop("masked_lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
dlbrt_output = self.distilbert(
input_ids=input_ids,
@ -549,6 +544,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
@ -556,12 +552,20 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
outputs = (prediction_logits,) + dlbrt_output[1:]
mlm_loss = None
if labels is not None:
mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
outputs = (mlm_loss,) + outputs
return outputs # (mlm_loss), prediction_logits, (all hidden_states), (all attentions)
if return_tuple:
output = (prediction_logits,) + dlbrt_output[1:]
return ((mlm_loss,) + output) if mlm_loss is not None else output
return MaskedLMOutput(
loss=mlm_loss,
logits=prediction_logits,
hidden_states=dlbrt_output.hidden_states,
attentions=dlbrt_output.attentions,
)
@add_start_docstrings(
@ -582,7 +586,12 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="distilbert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="distilbert-base-uncased",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -592,6 +601,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -599,26 +609,9 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
distilbert_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
@ -626,6 +619,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
@ -634,7 +628,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
pooled_output = self.dropout(pooled_output) # (bs, dim)
logits = self.classifier(pooled_output) # (bs, dim)
outputs = (logits,) + distilbert_output[1:]
loss = None
if labels is not None:
if self.num_labels == 1:
loss_fct = nn.MSELoss()
@ -642,9 +636,17 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + distilbert_output[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=distilbert_output.hidden_states,
attentions=distilbert_output.attentions,
)
@add_start_docstrings(
@ -664,7 +666,12 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="distilbert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="distilbert-base-uncased",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -675,6 +682,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -685,27 +693,9 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
distilbert_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
@ -713,6 +703,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
@ -722,7 +713,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
start_logits = start_logits.squeeze(-1) # (bs, max_query_len)
end_logits = end_logits.squeeze(-1) # (bs, max_query_len)
outputs = (start_logits, end_logits,) + distilbert_output[1:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -738,9 +729,18 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + distilbert_output[1:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=distilbert_output.hidden_states,
attentions=distilbert_output.attentions,
)
@add_start_docstrings(
@ -760,7 +760,12 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="distilbert-base-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="distilbert-base-uncased",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -770,30 +775,14 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.distilbert(
input_ids,
@ -802,6 +791,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -809,7 +799,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[1:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -822,9 +812,14 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -844,6 +839,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -853,6 +849,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -861,24 +858,6 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -900,6 +879,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
>>> loss, logits = outputs[:2]
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -917,6 +897,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
@ -928,11 +909,15 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
reshaped_logits = logits.view(-1, num_choices) # (bs, num_choices)
outputs = (reshaped_logits,) + outputs[1:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
if return_tuple:
output = (reshaped_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)

Просмотреть файл

@ -16,19 +16,23 @@
import logging
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import Tensor, nn
from .configuration_dpr import DPRConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
from .modeling_bert import BertModel
from .modeling_outputs import BaseModelOutputWithPooling
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "DPRConfig"
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-ctx_encoder-single-nq-base",
]
@ -40,6 +44,102 @@ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
##########
# Outputs
##########
@dataclass
class DPRContextEncoderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.DPRQuestionEncoder`.
Args:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the context representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed contexts for
nearest neighbors queries with questions embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
pooler_output: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DPRQuestionEncoderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.DPRQuestionEncoder`.
Args:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the question representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed questions for
nearest neighbors queries with context embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
pooler_output: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DPRReaderOutput(ModelOutput):
"""
Class for outputs of :class:`~transformers.DPRQuestionEncoder`.
Args:
start_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the start index of the span for each passage.
end_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the end index of the span for each passage.
relevance_logits: (:obj:`torch.FloatTensor`` of shape ``(n_passages, )``):
Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage
to answer the question, compared to all the other passages.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
start_logits: torch.FloatTensor
end_logits: torch.FloatTensor
relevance_logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class DPREncoder(PreTrainedModel):
base_model_prefix = "bert_model"
@ -61,28 +161,31 @@ class DPREncoder(PreTrainedModel):
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
) -> Tuple[Tensor, ...]:
return_tuple: bool = False,
) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
outputs = self.bert_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output, pooled_output, hidden_states = outputs[:3]
sequence_output, pooled_output = outputs[:2]
pooled_output = sequence_output[:, 0, :]
if self.projection_dim > 0:
pooled_output = self.encode_proj(pooled_output)
dpr_encoder_outputs = (sequence_output, pooled_output)
if return_tuple:
return (sequence_output, pooled_output) + outputs[2:]
if output_hidden_states:
dpr_encoder_outputs += (hidden_states,)
if output_attentions:
dpr_encoder_outputs += (outputs[-1],)
return dpr_encoder_outputs
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@property
def embeddings_size(self) -> int:
@ -114,7 +217,8 @@ class DPRSpanPredictor(PreTrainedModel):
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
):
return_tuple: bool = False,
) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
# notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
# feed encoder
@ -124,6 +228,7 @@ class DPRSpanPredictor(PreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -133,12 +238,22 @@ class DPRSpanPredictor(PreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
# resize and return
return (
start_logits.view(n_passages, sequence_length),
end_logits.view(n_passages, sequence_length),
relevance_logits.view(n_passages),
) + outputs[2:]
# resize
start_logits = start_logits.view(n_passages, sequence_length)
end_logits = end_logits.view(n_passages, sequence_length)
relevance_logits = relevance_logits.view(n_passages)
if return_tuple:
return (start_logits, end_logits, relevance_logits) + outputs[2:]
return DPRReaderOutput(
start_logits=start_logits,
end_logits=end_logits,
relevance_logits=relevance_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def init_weights(self):
self.encoder.init_weights()
@ -288,6 +403,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
self.init_weights()
@add_start_docstrings_to_callable(DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[Tensor] = None,
@ -296,26 +412,10 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
inputs_embeds: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
) -> Tensor:
return_tuple=None,
) -> Union[DPRContextEncoderOutput, Tuple[Tensor, ...]]:
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the context representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed contexts for
nearest neighbors queries with questions embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -331,6 +431,7 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -359,9 +460,14 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
if return_tuple:
return outputs[1:]
return DPRContextEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
sequence_output, pooled_output = outputs[:2]
return (pooled_output,) + outputs[2:]
@add_start_docstrings(
@ -376,6 +482,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
self.init_weights()
@add_start_docstrings_to_callable(DPR_ENCODERS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[Tensor] = None,
@ -384,26 +491,10 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
inputs_embeds: Optional[Tensor] = None,
output_attentions=None,
output_hidden_states=None,
) -> Tensor:
return_tuple=None,
) -> Union[DPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
pooler_output: (:obj:``torch.FloatTensor`` of shape ``(batch_size, embeddings_size)``):
The DPR encoder outputs the `pooler_output` that corresponds to the question representation.
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer. This output is to be used to embed questions for
nearest neighbors queries with context embeddings.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -417,6 +508,7 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -445,9 +537,14 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
if return_tuple:
return outputs[1:]
return DPRQuestionEncoderOutput(
pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
sequence_output, pooled_output = outputs[:2]
return (pooled_output,) + outputs[2:]
@add_start_docstrings(
@ -461,6 +558,7 @@ class DPRReader(DPRPretrainedReader):
self.init_weights()
@add_start_docstrings_to_callable(DPR_READER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DPRReaderOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[Tensor] = None,
@ -468,30 +566,10 @@ class DPRReader(DPRPretrainedReader):
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
) -> Tuple[Tensor, ...]:
return_tuple=None,
) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DPRConfig`) and inputs:
input_ids: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``)
They correspond to the combined `input_ids` from `(question + context title + context content`).
start_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the start index of the span for each passage.
end_logits: (:obj:``torch.FloatTensor`` of shape ``(n_passages, sequence_length)``):
Logits of the end index of the span for each passage.
relevance_logits: (:obj:`torch.FloatTensor`` of shape ``(n_passages, )``):
Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage
to answer the question, compared to all the other passages.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -514,6 +592,7 @@ class DPRReader(DPRPretrainedReader):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -529,13 +608,11 @@ class DPRReader(DPRPretrainedReader):
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
span_outputs = self.span_predictor(
return self.span_predictor(
input_ids,
attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
start_logits, end_logits, relevance_logits = span_outputs[:3]
return (start_logits, end_logits, relevance_logits) + span_outputs[3:]

Просмотреть файл

@ -1,6 +1,8 @@
import logging
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -8,13 +10,28 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .activations import get_activation
from .configuration_electra import ElectraConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
from .modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import SequenceSummary
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -168,6 +185,35 @@ class ElectraPreTrainedModel(BertPreTrainedModel):
base_model_prefix = "electra"
@dataclass
class ElectraForPretrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.ElectraForPretrainingModel`.
Args:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss of the ELECTRA objective.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`)
Prediction scores of the head (scores for each token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
ELECTRA_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
@ -224,6 +270,10 @@ ELECTRA_INPUTS_DOCSTRING = r"""
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -265,7 +315,12 @@ class ElectraModel(ElectraPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -276,29 +331,13 @@ class ElectraModel(ElectraPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -332,6 +371,7 @@ class ElectraModel(ElectraPreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
return hidden_states
@ -371,7 +411,12 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -383,6 +428,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -390,25 +436,9 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
discriminator_hidden_states = self.electra(
input_ids,
attention_mask,
@ -418,13 +448,13 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
inputs_embeds,
output_attentions,
output_hidden_states,
return_tuple,
)
sequence_output = discriminator_hidden_states[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + discriminator_hidden_states[1:] # add hidden states and attention if they are here
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -433,9 +463,17 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
@add_start_docstrings(
@ -455,6 +493,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ElectraForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -466,6 +505,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
@ -475,23 +515,6 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
``1`` indicates the token was replaced.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss of the ELECTRA objective.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`)
Prediction scores of the head (scores for each token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -505,6 +528,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
>>> scores = model(input_ids)[0]
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
discriminator_hidden_states = self.electra(
input_ids,
@ -515,13 +539,13 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
inputs_embeds,
output_attentions,
output_hidden_states,
return_tuple,
)
discriminator_sequence_output = discriminator_hidden_states[0]
logits = self.discriminator_predictions(discriminator_sequence_output)
output = (logits,)
loss = None
if labels is not None:
loss_fct = nn.BCEWithLogitsLoss()
if attention_mask is not None:
@ -532,11 +556,16 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
else:
loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())
output = (loss,) + output
if return_tuple:
output = (logits,) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
output += discriminator_hidden_states[1:]
return output # (loss), scores, (hidden_states), (attentions)
return ElectraForPretrainingOutput(
loss=loss,
logits=logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
@add_start_docstrings(
@ -561,7 +590,12 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
return self.generator_lm_head
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-generator")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -573,6 +607,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -583,24 +618,6 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if "masked_lm_labels" in kwargs:
warnings.warn(
@ -609,6 +626,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
)
labels = kwargs.pop("masked_lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
generator_hidden_states = self.electra(
input_ids,
@ -619,23 +637,29 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
inputs_embeds,
output_attentions,
output_hidden_states,
return_tuple,
)
generator_sequence_output = generator_hidden_states[0]
prediction_scores = self.generator_predictions(generator_sequence_output)
prediction_scores = self.generator_lm_head(prediction_scores)
output = (prediction_scores,)
loss = None
# Masked language modeling softmax layer
if labels is not None:
loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
output = (loss,) + output
output += generator_hidden_states[1:]
if return_tuple:
output = (prediction_scores,) + generator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
return output # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
return MaskedLMOutput(
loss=loss,
logits=prediction_scores,
hidden_states=generator_hidden_states.hidden_states,
attentions=generator_hidden_states.attentions,
)
@add_start_docstrings(
@ -655,7 +679,12 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -667,30 +696,14 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
discriminator_hidden_states = self.electra(
input_ids,
@ -701,14 +714,14 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
inputs_embeds,
output_attentions,
output_hidden_states,
return_tuple,
)
discriminator_sequence_output = discriminator_hidden_states[0]
discriminator_sequence_output = self.dropout(discriminator_sequence_output)
logits = self.classifier(discriminator_sequence_output)
output = (logits,)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
# Only keep active parts of the loss
@ -720,11 +733,16 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
else:
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
output = (loss,) + output
if return_tuple:
output = (logits,) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
output += discriminator_hidden_states[1:]
return output # (loss), scores, (hidden_states), (attentions)
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
@add_start_docstrings(
@ -747,7 +765,12 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -760,6 +783,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -770,27 +794,8 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
discriminator_hidden_states = self.electra(
input_ids,
@ -810,7 +815,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + discriminator_hidden_states[1:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -826,9 +831,18 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits,) + discriminator_hidden_states[1:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
@add_start_docstrings(
@ -847,7 +861,12 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -858,33 +877,15 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -905,6 +906,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
return_tuple=return_tuple,
)
sequence_output = discriminator_hidden_states[0]
@ -913,13 +915,18 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + discriminator_hidden_states[
1:
] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
if return_tuple:
output = (reshaped_logits,) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)

Просмотреть файл

@ -273,6 +273,7 @@ class EncoderDecoderModel(PreTrainedModel):
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
return_tuple=True,
**kwargs_encoder,
)
@ -287,6 +288,7 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
labels=labels,
return_tuple=True,
**kwargs_decoder,
)

Просмотреть файл

@ -23,6 +23,7 @@ from torch.nn import functional as F
from .configuration_flaubert import FlaubertConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_outputs import BaseModelOutput
from .modeling_xlm import (
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
@ -35,6 +36,7 @@ from .modeling_xlm import (
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "FlaubertConfig"
_TOKENIZER_FOR_DOC = "FlaubertTokenizer"
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -104,6 +106,10 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -121,7 +127,12 @@ class FlaubertModel(XLMModel):
self.pre_norm = getattr(config, "pre_norm", False)
@add_start_docstrings_to_callable(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="flaubert/flaubert_base_cased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="flaubert/flaubert_base_cased",
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -135,28 +146,13 @@ class FlaubertModel(XLMModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# removed: src_enc=None, src_len=None
if input_ids is not None:
@ -227,8 +223,8 @@ class FlaubertModel(XLMModel):
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# transformer layers
hidden_states = ()
attentions = ()
hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None
for i in range(self.n_layers):
# LayerDrop
dropout_probability = random.uniform(0, 1)
@ -286,12 +282,10 @@ class FlaubertModel(XLMModel):
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if output_hidden_states:
outputs = outputs + (hidden_states,)
if output_attentions:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)
if return_tuple:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
@add_start_docstrings(

Просмотреть файл

@ -19,6 +19,8 @@
import logging
import os
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
@ -26,7 +28,14 @@ from torch.nn import CrossEntropyLoss
from .activations import ACT2FN
from .configuration_gpt2 import GPT2Config
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modeling_utils import (
Conv1D,
PreTrainedModel,
@ -38,6 +47,7 @@ from .modeling_utils import (
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -280,6 +290,48 @@ class GPT2PreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
lm_loss: Optional[torch.FloatTensor]
mc_loss: Optional[torch.FloatTensor]
lm_logits: torch.FloatTensor
mc_logits: torch.FloatTensor
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
GPT2_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@ -339,6 +391,10 @@ GPT2_INPUTS_DOCSTRING = r"""
If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -372,7 +428,12 @@ class GPT2Model(GPT2PreTrainedModel):
self.h[layer].attn.prune_heads(heads)
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="gpt2",
output_type=BaseModelOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -385,33 +446,14 @@ class GPT2Model(GPT2PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``) is passed or when ``config.output_hidden_states=True``:
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -477,9 +519,9 @@ class GPT2Model(GPT2PreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
presents = ()
all_attentions = []
all_hidden_states = ()
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
@ -498,7 +540,7 @@ class GPT2Model(GPT2PreTrainedModel):
presents = presents + (present,)
if output_attentions:
all_attentions.append(outputs[2])
all_attentions = all_attentions + (outputs[2],)
hidden_states = self.ln_f(hidden_states)
@ -507,17 +549,15 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs # last hidden state, (presents), (all hidden_states), (attentions)
if return_tuple:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
@add_start_docstrings(
@ -544,7 +584,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -558,6 +603,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -566,28 +612,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
past=past,
@ -599,12 +626,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
@ -612,9 +640,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
if return_tuple:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -639,6 +676,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
return self.lm_head
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -654,6 +692,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -674,29 +713,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -729,6 +745,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
)
labels = kwargs.pop("lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
@ -741,6 +758,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
@ -748,16 +766,29 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
mc_loss = None
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
outputs = (loss,) + outputs
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
lm_loss = None
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
if return_tuple:
output = (lm_logits, mc_logits) + transformer_outputs[1:]
if mc_loss is not None:
output = (mc_loss,) + output
return ((lm_loss,) + output) if lm_loss is not None else output
return GPT2DoubleHeadsModelOutput(
lm_loss=lm_loss,
mc_loss=mc_loss,
lm_logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

Просмотреть файл

@ -24,14 +24,29 @@ from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from .configuration_longformer import LongformerConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_bert import BertIntermediate, BertLayerNorm, BertOutput, BertPooler, BertPreTrainedModel, BertSelfOutput
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_roberta import RobertaEmbeddings, RobertaLMHead
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "LongformerConfig"
_TOKENIZER_FOR_DOC = "LongformerTokenizer"
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -672,10 +687,15 @@ class LongformerEncoder(nn.Module):
self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])
def forward(
self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False,
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_tuple=False,
):
all_hidden_states = ()
all_attentions = ()
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -702,12 +722,11 @@ class LongformerEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
if return_tuple:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class LongformerPreTrainedModel(PreTrainedModel):
@ -788,6 +807,10 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -906,6 +929,7 @@ class LongformerModel(LongformerPreTrainedModel):
return attention_mask
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -916,24 +940,11 @@ class LongformerModel(LongformerPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -959,6 +970,7 @@ class LongformerModel(LongformerPreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -1002,24 +1014,25 @@ class LongformerModel(LongformerPreTrainedModel):
attention_mask=extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
# undo padding
if padding_len > 0:
# `output` has the following tensors: sequence_output, pooled_output, (hidden_states), (attentions)
# `sequence_output`: unpad because the calling function is expecting a length == input_ids.size(1)
# `pooled_output`: independent of the sequence length
# `hidden_states`: mainly used for debugging and analysis, so keep the padding
# `attentions`: mainly used for debugging and analysis, so keep the padding
outputs = outputs[0][:, :-padding_len], *outputs[1:]
# unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1)
sequence_output = sequence_output[:, :-padding_len]
return outputs
if return_tuple:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
@ -1036,6 +1049,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -1047,6 +1061,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -1059,22 +1074,6 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -1099,6 +1098,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
)
labels = kwargs.pop("masked_lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.longformer(
input_ids,
@ -1109,18 +1109,26 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1142,7 +1150,12 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1154,6 +1167,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1161,25 +1175,8 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if global_attention_mask is None:
logger.info("Initializing global attention on CLS token...")
@ -1196,11 +1193,12 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:]
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -1209,9 +1207,14 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
class LongformerClassificationHead(nn.Module):
@ -1252,6 +1255,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -1264,6 +1268,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1275,24 +1280,6 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -1317,6 +1304,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
>>> answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepending space token
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# set global attention on question tokens
if global_attention_mask is None:
@ -1333,6 +1321,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1342,7 +1331,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1358,9 +1347,18 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1383,7 +1381,12 @@ class LongformerForTokenClassification(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1395,30 +1398,14 @@ class LongformerForTokenClassification(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.longformer(
input_ids,
@ -1429,6 +1416,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1436,8 +1424,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -1450,9 +1437,14 @@ class LongformerForTokenClassification(BertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1474,7 +1466,12 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="allenai/longformer-base-4096")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1486,34 +1483,16 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
loss (:obj:`torch.FloatTensor`` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# set global attention on question tokens
if global_attention_mask is None:
@ -1551,6 +1530,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
@ -1558,11 +1538,15 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
if return_tuple:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)

Просмотреть файл

@ -22,12 +22,15 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import add_start_docstrings
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
from .modeling_outputs import BaseModelOutputWithPooling
from .modeling_utils import ModuleUtilsMixin
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "MMBTConfig"
class ModalEmbeddings(nn.Module):
"""Generic Modal Embeddings which takes in an encoder, and a transformer embedding.
@ -100,91 +103,68 @@ MMBT_START_DOCSTRING = r""" MMBT model was proposed in
"""
MMBT_INPUTS_DOCSTRING = r""" Inputs:
**input_modal**: ``torch.FloatTensor`` of shape ``(batch_size, ***)``:
input_modal (``torch.FloatTensor`` of shape ``(batch_size, ***)``):
The other modality data. It will be the shape that the encoder for that type expects.
e.g. With an Image Encoder, the shape would be (batch_size, channels, height, width)
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
input_ids (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``):
Indices of input sequence tokens in the vocabulary.
It does not expect [CLS] token to be added as it's appended to the end of other modality embeddings.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**modal_start_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
modal_start_tokens (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for Classification tasks.
**modal_end_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
modal_end_tokens (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
attention_mask (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
token_type_ids (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Segment token indices to indicate different portions of the inputs.
**modal_token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
modal_token_type_ids (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
Segment token indices to indicate different portions of the non-text modality.
The embeddings from these tokens will be summed with the respective token embeddings for the non-text modality.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
position_ids (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
Indices of positions of each input sequence tokens in the position embeddings.
**modal_position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
modal_position_ids (``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``, `optional`):
Indices of positions of each input sequence tokens in the position embeddings for the non-text modality.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
head_mask (``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``, `optional`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
inputs_embeds (``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``, `optional`):
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
**encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
encoder_hidden_states (``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
is configured as a decoder.
**encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
encoder_attention_mask (``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
is used in the cross-attention if the model is configured as a decoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@add_start_docstrings(
"The bare MMBT Model outputting raw hidden-states without any specific head on top.",
MMBT_START_DOCSTRING,
MMBT_INPUTS_DOCSTRING,
"The bare MMBT Model outputting raw hidden-states without any specific head on top.", MMBT_START_DOCSTRING,
)
class MMBTModel(nn.Module, ModuleUtilsMixin):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
# For example purposes. Not runnable.
transformer = BertModel.from_pretrained('bert-base-uncased')
encoder = ImageEncoder(args)
mmbt = MMBTModel(config, transformer, encoder)
"""
def __init__(self, config, transformer, encoder):
super().__init__()
self.config = config
self.transformer = transformer
self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)
@add_start_docstrings_to_callable(MMBT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_modal,
@ -200,8 +180,25 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Returns:
Examples::
# For example purposes. Not runnable.
transformer = BertModel.from_pretrained('bert-base-uncased')
encoder = ImageEncoder(args)
mmbt = MMBTModel(config, transformer, encoder)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -258,16 +255,23 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = encoder_outputs[0]
pooled_output = self.transformer.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
if return_tuple:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings.word_embeddings

Просмотреть файл

@ -24,6 +24,8 @@ import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@ -34,12 +36,29 @@ from transformers.modeling_bert import BertIntermediate
from .activations import gelu, gelu_new, swish
from .configuration_mobilebert import MobileBertConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]
@ -528,9 +547,10 @@ class MobileBertEncoder(nn.Module):
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_tuple=False,
):
all_hidden_states = ()
all_attentions = ()
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -552,12 +572,11 @@ class MobileBertEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
if return_tuple:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class MobileBertPooler(nn.Module):
@ -660,6 +679,39 @@ class MobileBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
@dataclass
class MobileBertForPretrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.MobileBertForPretrainingModel`.
Args:
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
prediction_logits: torch.FloatTensor
seq_relationship_logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
MOBILEBERT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
@ -714,6 +766,12 @@ MOBILEBERT_INPUTS_DOCSTRING = r"""
is used in the cross-attention if the model is configured as a decoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -750,7 +808,12 @@ class MobileBertModel(MobileBertPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -763,38 +826,13 @@ class MobileBertModel(MobileBertPreTrainedModel):
encoder_attention_mask=None,
output_hidden_states=None,
output_attentions=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -847,13 +885,20 @@ class MobileBertModel(MobileBertPreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + encoder_outputs[
1:
] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
if return_tuple:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
@ -895,6 +940,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MobileBertForPretrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -907,6 +953,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
@ -920,25 +967,6 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -954,6 +982,8 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
>>> prediction_scores, seq_relationship_scores = outputs[:2]
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.mobilebert(
input_ids,
attention_mask=attention_mask,
@ -963,21 +993,29 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
outputs = (prediction_scores, seq_relationship_score,) + outputs[
2:
] # add hidden states and attention if they are here
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
outputs = (total_loss,) + outputs
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return MobileBertForPretrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
@ -1016,7 +1054,12 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1030,6 +1073,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -1040,24 +1084,6 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if "masked_lm_labels" in kwargs:
warnings.warn(
@ -1065,6 +1091,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
FutureWarning,
)
labels = kwargs.pop("masked_lm_labels")
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.mobilebert(
input_ids,
@ -1077,19 +1104,27 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class MobileBertOnlyNSPHead(nn.Module):
@ -1116,6 +1151,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -1127,6 +1163,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1136,22 +1173,6 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
``1`` indicates sequence B is a random sequence.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
Next sequence prediction (classification) loss.
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -1167,6 +1188,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
>>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.mobilebert(
input_ids,
@ -1177,19 +1199,27 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
seq_relationship_score = self.cls(pooled_output)
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
next_sentence_loss = None
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
outputs = (next_sentence_loss,) + outputs
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
if return_tuple:
output = (seq_relationship_score,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1208,7 +1238,12 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1220,6 +1255,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1227,24 +1263,8 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.mobilebert(
input_ids,
@ -1255,11 +1275,13 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -1268,8 +1290,14 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1288,7 +1316,12 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1301,6 +1334,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1311,27 +1345,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.mobilebert(
input_ids,
@ -1342,6 +1357,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1351,7 +1367,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1367,9 +1383,18 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1388,7 +1413,12 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1400,33 +1430,15 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -1448,6 +1460,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
@ -1456,14 +1469,18 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
if return_tuple:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1483,7 +1500,12 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(MOBILEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/mobilebert-uncased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1495,30 +1517,14 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.mobilebert(
input_ids,
@ -1529,6 +1535,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1536,7 +1543,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -1549,6 +1556,11 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)

Просмотреть файл

@ -21,6 +21,8 @@ import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
@ -28,7 +30,14 @@ from torch.nn import CrossEntropyLoss
from .activations import gelu_new, swish
from .configuration_openai import OpenAIGPTConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutput, CausalLMOutput
from .modeling_utils import (
Conv1D,
PreTrainedModel,
@ -40,6 +49,7 @@ from .modeling_utils import (
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -277,6 +287,41 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@dataclass
class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
lm_loss: Optional[torch.FloatTensor]
mc_loss: Optional[torch.FloatTensor]
lm_logits: torch.FloatTensor
mc_logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
OPENAI_GPT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@ -326,6 +371,10 @@ OPENAI_GPT_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -358,7 +407,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.h[layer].attn.prune_heads(heads)
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="openai-gpt")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="openai-gpt",
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -369,28 +423,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -441,8 +480,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
all_attentions = ()
all_hidden_states = ()
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, block in enumerate(self.h):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
@ -452,16 +491,17 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
if output_attentions:
all_attentions = all_attentions + (outputs[1],)
hidden_states = hidden_states.view(*output_shape)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states.view(*output_shape),)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last hidden state, (all hidden states), (all attentions)
if return_tuple:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions,
)
@add_start_docstrings(
@ -481,7 +521,12 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
return self.lm_head
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="openai-gpt")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="openai-gpt",
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -493,6 +538,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -501,29 +547,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -533,11 +559,12 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = (lm_logits,) + transformer_outputs[1:]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
@ -545,9 +572,17 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, (all hidden states), (all attentions)
if return_tuple:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(
loss=loss,
logits=lm_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -573,6 +608,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
return self.lm_head
@add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -586,6 +622,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -606,30 +643,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -647,8 +660,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
outputs = model(input_ids, mc_token_ids=mc_token_ids)
lm_prediction_scores, mc_prediction_scores = outputs[:2]
"""
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if "lm_labels" in kwargs:
warnings.warn(
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
@ -666,22 +679,35 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
lm_loss = None
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
outputs = (loss,) + outputs
lm_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
mc_loss = None
if labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
outputs = (loss,) + outputs
mc_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return outputs # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
if return_tuple:
output = (lm_logits, mc_logits) + transformer_outputs[1:]
if mc_loss is not None:
output = (mc_loss,) + output
return ((lm_loss,) + output) if lm_loss is not None else output
return OpenAIGPTDoubleHeadsModelOutput(
lm_loss=lm_loss,
mc_loss=mc_loss,
lm_logits=lm_logits,
mc_logits=mc_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

Просмотреть файл

@ -0,0 +1,559 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
from .file_utils import ModelOutput
@dataclass
class BaseModelOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithPooling(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during pre-training.
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor
pooler_output: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithPast(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqModelOutput(ModelOutput):
"""
Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
decoding.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If `decoder_past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
last_hidden_state: torch.FloatTensor
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Language modeling loss (for next-token prediction).
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CausalLMOutputWithPast(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Language modeling loss (for next-token prediction).
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskedLMOutput(ModelOutput):
"""
Base class for masked language models outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Masked languaged modeling (MLM) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqLMOutput(ModelOutput):
"""
Base class for sequence-to-sequence language models outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Languaged modeling loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class NextSentencePredictorOutput(ModelOutput):
"""
Base class for outputs of models predicting if two sentences are consecutive or not.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
Next sequence prediction (classification) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class SequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class TokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class QuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of question answering models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
start_logits: torch.FloatTensor
end_logits: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence question answering models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
loss: Optional[torch.FloatTensor]
start_logits: torch.FloatTensor
end_logits: torch.FloatTensor
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None

Просмотреть файл

@ -36,11 +36,13 @@ from .file_utils import (
add_start_docstrings,
add_start_docstrings_to_callable,
)
from .modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "ReformerConfig"
_TOKENIZER_FOR_DOC = "ReformerTokenizer"
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -1493,6 +1495,10 @@ REFORMER_INPUTS_DOCSTRING = r"""
For more information, see `num_hashes` in :class:`transformers.ReformerConfig`.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -1528,7 +1534,12 @@ class ReformerModel(ReformerPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/reformer-crime-and-punishment",
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1539,29 +1550,13 @@ class ReformerModel(ReformerPreTrainedModel):
num_hashes=None,
output_hidden_states=None,
output_attentions=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -1628,13 +1623,12 @@ class ReformerModel(ReformerPreTrainedModel):
if must_pad_to_match_chunk_length:
sequence_output = sequence_output[:, :orig_sequence_length]
outputs = (sequence_output,)
# TODO(PVP): Replace by named tuple after namedtuples are introduced in the library.
if output_hidden_states is True:
outputs = outputs + (encoder_outputs.all_hidden_states,)
if output_attentions is True:
outputs = outputs + (encoder_outputs.all_attentions,)
return outputs
hidden_states = encoder_outputs.all_hidden_states if output_hidden_states else None
attentions = encoder_outputs.all_attentions if output_attentions else None
if return_tuple:
return tuple(v for v in [sequence_output, hidden_states, attentions] if v is not None)
return BaseModelOutput(last_hidden_state=sequence_output, hidden_states=hidden_states, attentions=attentions)
def _pad_to_mult_of_chunk_length(
self,
@ -1712,7 +1706,12 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
pass
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/reformer-crime-and-punishment",
output_type=CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1724,6 +1723,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
labels=None,
output_hidden_states=None,
output_attentions=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1731,25 +1731,8 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
Indices should be in :obj:`[-100, 0, ..., config.vocab_size - 1]`.
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss (cross entropy).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
reformer_outputs = self.reformer(
input_ids,
@ -1760,12 +1743,13 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
num_hashes=num_hashes,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_tuple=return_tuple,
)
sequence_output = reformer_outputs[0]
logits = self.lm_head(sequence_output)
outputs = (logits,) + reformer_outputs[1:]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
@ -1773,8 +1757,17 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
outputs = (loss,) + outputs
return outputs # (lm_loss), lm_logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + reformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(
loss=loss,
logits=logits,
hidden_states=reformer_outputs.hidden_states,
attentions=reformer_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
# TODO(PVP): Add smart caching
@ -1806,7 +1799,12 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
pass
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/reformer-crime-and-punishment",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1818,31 +1816,15 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
labels=None,
output_hidden_states=None,
output_attentions=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss (cross entropy).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
reformer_outputs = self.reformer(
input_ids,
@ -1853,18 +1835,27 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
num_hashes=num_hashes,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_tuple=return_tuple,
)
sequence_output = reformer_outputs[0]
logits = self.lm_head(sequence_output)
outputs = (logits,) + reformer_outputs[1:]
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (mlm_loss), lm_logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + reformer_outputs[1:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=logits,
hidden_states=reformer_outputs.hidden_states,
attentions=reformer_outputs.attentions,
)
@add_start_docstrings(
@ -1889,7 +1880,12 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
pass
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/reformer-crime-and-punishment",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1902,6 +1898,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
end_positions=None,
output_hidden_states=None,
output_attentions=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1912,26 +1909,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
reformer_outputs = self.reformer(
input_ids,
@ -1942,6 +1921,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
num_hashes=num_hashes,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_tuple=return_tuple,
)
sequence_output = reformer_outputs[0]
@ -1951,8 +1931,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + reformer_outputs[1:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1968,6 +1947,15 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + reformer_outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=reformer_outputs.hidden_states,
attentions=reformer_outputs.attentions,
)

Просмотреть файл

@ -26,10 +26,18 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .configuration_roberta import RobertaConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
from .modeling_outputs import (
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -133,6 +141,10 @@ ROBERTA_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -179,7 +191,12 @@ class RobertaForMaskedLM(BertPreTrainedModel):
return self.lm_head.decoder
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="roberta-base",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -191,6 +208,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -201,24 +219,6 @@ class RobertaForMaskedLM(BertPreTrainedModel):
in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if "masked_lm_labels" in kwargs:
warnings.warn(
@ -227,6 +227,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
)
labels = kwargs.pop("masked_lm_labels")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.roberta(
input_ids,
@ -237,18 +238,26 @@ class RobertaForMaskedLM(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
if return_tuple:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class RobertaLMHead(nn.Module):
@ -295,7 +304,12 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="roberta-base",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -307,6 +321,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -314,25 +329,9 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
@ -342,11 +341,12 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:]
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -355,9 +355,14 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -379,7 +384,12 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="roberta-base",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -391,33 +401,15 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
loss (:obj:`torch.FloatTensor`` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -439,6 +431,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
pooled_output = outputs[1]
@ -446,14 +439,18 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
if return_tuple:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
@add_start_docstrings(
@ -476,7 +473,12 @@ class RobertaForTokenClassification(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="roberta-base",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -488,30 +490,14 @@ class RobertaForTokenClassification(BertPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.roberta(
input_ids,
@ -522,6 +508,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -529,8 +516,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -543,9 +529,14 @@ class RobertaForTokenClassification(BertPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
class RobertaClassificationHead(nn.Module):
@ -586,7 +577,12 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="roberta-base",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -599,6 +595,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -609,27 +606,8 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.roberta(
input_ids,
@ -640,6 +618,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -649,7 +628,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -665,9 +644,18 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def create_position_ids_from_input_ids(input_ids, padding_idx):

Просмотреть файл

@ -27,12 +27,20 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
####################################################
@ -667,6 +675,7 @@ class T5Stack(T5PreTrainedModel):
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
@ -674,6 +683,7 @@ class T5Stack(T5PreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -704,6 +714,9 @@ class T5Stack(T5PreTrainedModel):
else:
mask_seq_length = seq_length
if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
@ -726,9 +739,9 @@ class T5Stack(T5PreTrainedModel):
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
present_key_value_states = ()
all_hidden_states = ()
all_attentions = ()
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
position_bias = None
encoder_decoder_position_bias = None
@ -761,7 +774,8 @@ class T5Stack(T5PreTrainedModel):
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now
@ -773,15 +787,18 @@ class T5Stack(T5PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
outputs = outputs + (present_key_value_states,)
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (presents,) (all hidden states), (all attentions)
if return_tuple:
return tuple(
v
for v in [hidden_states, present_key_value_states, all_hidden_states, all_attentions]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
T5_START_DOCSTRING = r"""
@ -849,6 +866,10 @@ T5_INPUTS_DOCSTRING = r"""
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -894,6 +915,7 @@ class T5Model(T5PreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -908,42 +930,25 @@ class T5Model(T5PreTrainedModel):
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Example::
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
>>> from transformers import T5Tokenizer, T5Model
Example::
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5Model.from_pretrained('t5-small')
>>> from transformers import T5Tokenizer, T5Model
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5Model.from_pretrained('t5-small')
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
@ -954,6 +959,13 @@ class T5Model(T5PreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
elif not return_tuple and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
@ -984,13 +996,24 @@ class T5Model(T5PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
if use_cache is True:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
if return_tuple:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + encoder_outputs
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
@ -1031,6 +1054,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
return self.decoder
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -1046,6 +1070,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
**kwargs
):
r"""
@ -1058,27 +1083,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss (cross entropy).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
If `past_key_value_states` is used only the last prediction_scores of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -1105,6 +1109,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
@ -1116,6 +1121,13 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
elif not return_tuple and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
@ -1145,28 +1157,38 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
# insert decoder past at right place
# to speed up decoding
if use_cache is True:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
sequence_output = decoder_outputs[0]
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim ** -0.5)
lm_logits = self.lm_head(sequence_output)
decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
decoder_outputs = (loss,) + decoder_outputs
return decoder_outputs + encoder_outputs
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
if return_tuple:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
decoder_past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"

Просмотреть файл

@ -20,20 +20,22 @@
import logging
from typing import Optional
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "TransfoXLConfig"
_TOKENIZER_FOR_DOC = "TransfoXLTokenizer"
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -590,6 +592,73 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
return embeddings.cutoffs
@dataclass
class TransfoXLModelOutput(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor
mems: List[torch.FloatTensor]
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class TransfoXLLMHeadModelOutput(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
Language modeling loss (for next-token prediction).
losses (:obj:`torch.FloatTensor` of shape `(batch_size, sequence_length-1)`, `optional`, returned when ``labels`` is provided)
Language modeling losses (not reduced).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
losses: Optional[torch.FloatTensor]
prediction_scores: torch.FloatTensor
mems: List[torch.FloatTensor]
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
TRANSFO_XL_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@ -626,6 +695,10 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -751,7 +824,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="transfo-xl-wt103")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="transfo-xl-wt103",
output_type=TransfoXLModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -760,32 +838,13 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
@ -841,7 +900,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
]
hids = []
attentions = []
attentions = [] if output_attentions else None
if self.attn_type == 0: # default
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
if self.clamp_len > 0:
@ -872,19 +931,24 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
new_mems = self._update_mems(hids, mems, mlen, qlen)
# We transpose back here to shape [bsz, len, hidden_dim]
outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
if output_hidden_states:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out)
hids = list(t.transpose(0, 1).contiguous() for t in hids)
outputs.append(hids)
hids = tuple(t.transpose(0, 1).contiguous() for t in hids)
else:
hids = None
if output_attentions:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs.append(attentions)
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
# We transpose back here to shape [bsz, len, hidden_dim]
core_out = core_out.transpose(0, 1).contiguous()
return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
if return_tuple:
return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None)
return TransfoXLModelOutput(
last_hidden_state=core_out, mems=new_mems, hidden_states=hids, attentions=attentions,
)
@add_start_docstrings(
@ -936,7 +1000,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
return self.transformer.init_mems(bsz)
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="transfo-xl-wt103")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="transfo-xl-wt103",
output_type=TransfoXLLMHeadModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -946,6 +1015,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -954,29 +1024,8 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(batch_size, sequence_length-1)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None:
bsz, tgt_len = input_ids.size(0), input_ids.size(1)
elif inputs_embeds is not None:
@ -991,6 +1040,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
last_hidden = transformer_outputs[0]
@ -998,14 +1048,20 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
outputs = transformer_outputs[1:]
softmax_output = self.crit(pred_hid, labels)
if labels is None:
softmax_output = softmax_output.view(bsz, tgt_len, -1)
outputs = [softmax_output] + outputs
else:
softmax_output = softmax_output.view(bsz, tgt_len - 1)
outputs = [softmax_output, None] + outputs
prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
if return_tuple:
output = (prediction_scores,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TransfoXLLMHeadModelOutput(
losses=loss,
prediction_scores=prediction_scores,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def get_output_embeddings(self):
""" Double-check if you are using adaptive softmax.

Просмотреть файл

@ -17,6 +17,7 @@
import inspect
import logging
import os
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
import torch
@ -31,6 +32,7 @@ from .file_utils import (
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_NAME,
ModelOutput,
cached_path,
hf_bucket_url,
is_remote_url,
@ -941,6 +943,35 @@ class PoolerAnswerClass(nn.Module):
return x
@dataclass
class SquadHeadOutput(ModelOutput):
"""
Base class for outputs of question answering models using a :obj:`SquadHead`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top config.start_n_top start token possibilities (beam-search).
end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
"""
loss: Optional[torch.FloatTensor] = None
start_top_log_probs: Optional[torch.FloatTensor] = None
start_top_index: Optional[torch.LongTensor] = None
end_top_log_probs: Optional[torch.FloatTensor] = None
end_top_index: Optional[torch.LongTensor] = None
cls_logits: Optional[torch.FloatTensor] = None
class SQuADHead(nn.Module):
r""" A SQuAD head inspired by XLNet.
@ -992,10 +1023,15 @@ class SQuADHead(nn.Module):
self.answer_class = PoolerAnswerClass(config)
def forward(
self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
self,
hidden_states,
start_positions=None,
end_positions=None,
cls_index=None,
is_impossible=None,
p_mask=None,
return_tuple=False,
):
outputs = ()
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
if start_positions is not None and end_positions is not None:
@ -1021,7 +1057,7 @@ class SQuADHead(nn.Module):
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
outputs = (total_loss,) + outputs
return (total_loss,) if return_tuple else SquadHeadOutput(loss=total_loss)
else:
# during inference, compute the end logits based on beam search
@ -1051,11 +1087,16 @@ class SQuADHead(nn.Module):
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
# or (if labels are provided) (total_loss,)
return outputs
if return_tuple:
return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
else:
return SquadHeadOutput(
start_top_log_probs=start_top_log_probs,
start_top_index=start_top_index,
end_top_log_probs=end_top_log_probs,
end_top_index=end_top_index,
cls_logits=cls_logits,
)
class SequenceSummary(nn.Module):

Просмотреть файл

@ -19,6 +19,8 @@
import itertools
import logging
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
@ -28,7 +30,20 @@ from torch.nn import functional as F
from .activations import gelu
from .configuration_xlm import XLMConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from .modeling_utils import (
PreTrainedModel,
SequenceSummary,
@ -40,6 +55,7 @@ from .modeling_utils import (
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "XLMConfig"
_TOKENIZER_FOR_DOC = "XLMTokenizer"
XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -240,6 +256,47 @@ class XLMPreTrainedModel(PreTrainedModel):
module.weight.data.fill_(1.0)
@dataclass
class XLMForQuestionAnsweringOutput(ModelOutput):
"""
Base class for outputs of question answering models using a :obj:`SquadHead`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top config.start_n_top start token possibilities (beam-search).
end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
start_top_log_probs: Optional[torch.FloatTensor] = None
start_top_index: Optional[torch.LongTensor] = None
end_top_log_probs: Optional[torch.FloatTensor] = None
end_top_index: Optional[torch.LongTensor] = None
cls_logits: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
XLM_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@ -306,6 +363,10 @@ XLM_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -397,7 +458,12 @@ class XLMModel(XLMPreTrainedModel):
self.attentions[layer].prune_heads(heads)
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlm-mlm-en-2048",
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -411,28 +477,13 @@ class XLMModel(XLMPreTrainedModel):
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
if input_ids is not None:
bs, slen = input_ids.size()
@ -502,8 +553,8 @@ class XLMModel(XLMPreTrainedModel):
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# transformer layers
hidden_states = ()
attentions = ()
hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None
for i in range(self.n_layers):
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
@ -542,12 +593,9 @@ class XLMModel(XLMPreTrainedModel):
# move back sequence length to dimension 0
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if output_hidden_states:
outputs = outputs + (hidden_states,)
if output_attentions:
outputs = outputs + (attentions,)
return outputs # outputs, (hidden_states), (attentions)
if return_tuple:
return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
class XLMPredLayer(nn.Module):
@ -623,7 +671,12 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
return {"input_ids": input_ids, "langs": langs}
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlm-mlm-en-2048",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -638,6 +691,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -646,25 +700,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -677,13 +715,21 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
output = transformer_outputs[0]
outputs = self.pred_layer(output, labels)
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
outputs = self.pred_layer(output, labels) # (loss, logits) or (logits,) depending on if labels are provided.
return outputs
if return_tuple:
return outputs + transformer_outputs[1:]
return MaskedLMOutput(
loss=outputs[0] if labels is not None else None,
logits=outputs[0] if labels is None else outputs[1],
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -702,7 +748,12 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlm-mlm-en-2048",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -717,6 +768,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -724,25 +776,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -755,13 +791,13 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
output = transformer_outputs[0]
logits = self.sequence_summary(output)
outputs = (logits,) + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -770,9 +806,17 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs
if return_tuple:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -790,7 +834,12 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlm-mlm-en-2048",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -806,6 +855,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -816,27 +866,9 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -849,6 +881,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = transformer_outputs[0]
@ -858,10 +891,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (
start_logits,
end_logits,
)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -877,11 +907,18 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
if return_tuple:
output = (start_logits, end_logits) + transformer_outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return outputs
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -899,6 +936,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=XLMForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -917,6 +955,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
p_mask=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -936,30 +975,6 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
1.0 means token should be masked. 0.0 mean token is not masked.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top config.start_n_top start token possibilities (beam-search).
end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Example::
@ -976,6 +991,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
>>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
>>> loss = outputs[0]
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -988,6 +1005,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
output = transformer_outputs[0]
@ -999,11 +1017,22 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
cls_index=cls_index,
is_impossible=is_impossible,
p_mask=p_mask,
return_tuple=return_tuple,
)
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
if return_tuple:
return outputs + transformer_outputs[1:]
return outputs
return XLMForQuestionAnsweringOutput(
loss=outputs.loss,
start_top_log_probs=outputs.start_top_log_probs,
start_top_index=outputs.start_top_index,
end_top_log_probs=outputs.end_top_log_probs,
end_top_index=outputs.end_top_index,
cls_logits=outputs.cls_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -1023,7 +1052,12 @@ class XLMForTokenClassification(XLMPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlm-mlm-en-2048")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlm-mlm-en-2048",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1034,33 +1068,19 @@ class XLMForTokenClassification(XLMPreTrainedModel):
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -1070,8 +1090,10 @@ class XLMForTokenClassification(XLMPreTrainedModel):
lengths=lengths,
cache=cache,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1079,7 +1101,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -1092,6 +1114,11 @@ class XLMForTokenClassification(XLMPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
if return_tuple:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)

Просмотреть файл

@ -55,6 +55,10 @@ XLM_ROBERTA_START_DOCSTRING = r"""
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""

Просмотреть файл

@ -18,6 +18,8 @@
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
from torch import nn
@ -26,12 +28,19 @@ from torch.nn import functional as F
from .activations import gelu_new, swish
from .configuration_xlnet import XLNetConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
replace_return_docstrings,
)
from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
logger = logging.getLogger(__name__)
_CONFIG_FOR_DOC = "XLNetConfig"
_TOKENIZER_FOR_DOC = "XLNetTokenizer"
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
@ -554,6 +563,264 @@ class XLNetPreTrainedModel(PreTrainedModel):
module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)
@dataclass
class XLNetModelOutput(ModelOutput):
"""
Output type of :class:`~transformers.XLNetModel`.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
``num_predict`` corresponds to ``sequence_length``.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor
mems: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XLNetLMHeadModelOutput(ModelOutput):
"""
Output type of :class:`~transformers.XLNetModel`.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss (for next-token prediction).
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
``num_predict`` corresponds to ``sequence_length``.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
mems: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XLNetForSequenceClassificationOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
mems: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XLNetForTokenClassificationOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
mems: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XLNetForMultipleChoiceOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
logits: torch.FloatTensor
mems: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XLNetForQuestionAnsweringSimpleOutput(ModelOutput):
"""
Base class for outputs of question answering models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor]
start_logits: torch.FloatTensor
end_logits: torch.FloatTensor
mems: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class XLNetForQuestionAnsweringOutput(ModelOutput):
"""
Base class for outputs of question answering models using a :obj:`SquadHead`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top config.start_n_top start token possibilities (beam-search).
end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
start_top_log_probs: Optional[torch.FloatTensor] = None
start_top_index: Optional[torch.LongTensor] = None
end_top_log_probs: Optional[torch.FloatTensor] = None
end_top_index: Optional[torch.LongTensor] = None
cls_logits: Optional[torch.FloatTensor] = None
mems: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
XLNET_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
@ -622,6 +889,10 @@ XLNET_INPUTS_DOCSTRING = r"""
If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
"""
@ -751,7 +1022,12 @@ class XLNetModel(XLNetPreTrainedModel):
return pos_emb
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlnet-base-cased",
output_type=XLNetModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -766,33 +1042,13 @@ class XLNetModel(XLNetPreTrainedModel):
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
`num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
@ -920,8 +1176,8 @@ class XLNetModel(XLNetPreTrainedModel):
if mems is None:
mems = [None] * len(self.layer)
attentions = []
hidden_states = []
attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layer):
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
# cache new mems
@ -952,17 +1208,18 @@ class XLNetModel(XLNetPreTrainedModel):
output = self.dropout(output_g if output_g is not None else output_h)
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (output.permute(1, 0, 2).contiguous(),)
output = output.permute(1, 0, 2).contiguous()
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
outputs = outputs + (new_mems,)
# TODO Teven: fix this test to only use use_cache.
if not (self.mem_len is not None and self.mem_len > 0 and use_cache is True):
new_mems = None
if output_hidden_states:
if output_g is not None:
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
else:
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
outputs = outputs + (hidden_states,)
if output_attentions:
if target_mapping is not None:
# when target_mapping is provided, there are 2-tuple of attentions
@ -971,9 +1228,13 @@ class XLNetModel(XLNetPreTrainedModel):
)
else:
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs = outputs + (attentions,)
return outputs # outputs, (new_mems), (hidden_states), (attentions)
if return_tuple:
return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
return XLNetModelOutput(
last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
)
@add_start_docstrings(
@ -1029,6 +1290,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return inputs
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=XLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -1040,10 +1302,11 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
labels=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
@ -1055,27 +1318,6 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
computed for labels in ``[0, ..., config.vocab_size]``
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
`num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict` corresponds to `sequence_length`.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
@ -1108,6 +1350,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
loss, next_token_logits = outputs[:2] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -1121,19 +1365,28 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
logits = self.lm_loss(transformer_outputs[0])
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
loss = None
if labels is not None:
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
if return_tuple:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return XLNetLMHeadModelOutput(
loss=loss,
logits=logits,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -1153,7 +1406,12 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlnet-base-cased",
output_type=XLNetForSequenceClassificationOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1165,10 +1423,11 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`)
@ -1176,29 +1435,9 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -1212,14 +1451,14 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
output = transformer_outputs[0]
output = self.sequence_summary(output)
logits = self.logits_proj(output)
outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
@ -1228,9 +1467,18 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
if return_tuple:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return XLNetForSequenceClassificationOutput(
loss=loss,
logits=logits,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -1249,7 +1497,12 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlnet-base-cased",
output_type=XLNetForTokenClassificationOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1261,39 +1514,19 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:(batch_size, config.num_labels)`):
Classification scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.transformer(
input_ids,
@ -1308,13 +1541,14 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[1:] # Keep mems, hidden states, attentions if there are in it
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
@ -1327,9 +1561,18 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
if return_tuple:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return XLNetForTokenClassificationOutput(
loss=loss,
logits=logits,
mems=outputs.mems,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1348,7 +1591,12 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlnet-base-cased",
output_type=XLNetForMultipleChoiceOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1360,41 +1608,19 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
target_mapping=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`torch.FloatTensor`` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@ -1420,6 +1646,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
output = transformer_outputs[0]
@ -1427,16 +1654,23 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
output = self.sequence_summary(output)
logits = self.logits_proj(output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + transformer_outputs[
1:
] # Keep mems, hidden states, attentions if there are in it
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels.view(-1))
outputs = (loss,) + outputs
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
if return_tuple:
output = (reshaped_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return XLNetForMultipleChoiceOutput(
loss=loss,
logits=reshaped_logits,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@add_start_docstrings(
@ -1455,7 +1689,12 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="xlnet-base-cased",
output_type=XLNetForQuestionAnsweringSimpleOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
@ -1467,11 +1706,12 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
start_positions=None,
end_positions=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1482,31 +1722,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
outputs = self.transformer(
input_ids,
@ -1521,6 +1738,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
sequence_output = outputs[0]
@ -1530,7 +1748,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
outputs = (start_logits, end_logits,) + outputs[2:]
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
@ -1546,9 +1764,19 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
if return_tuple:
output = (start_logits, end_logits) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return XLNetForQuestionAnsweringSimpleOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
mems=outputs.mems,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
@ -1570,6 +1798,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
self.init_weights()
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@replace_return_docstrings(output_type=XLNetForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
@ -1581,14 +1810,15 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
start_positions=None,
end_positions=None,
is_impossible=None,
cls_index=None,
p_mask=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
return_tuple=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1608,50 +1838,24 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
1.0 means token should be masked. 0.0 mean token is not masked.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top config.start_n_top start token possibilities (beam-search).
end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Example::
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
>>> from transformers import XLNetTokenizer, XLNetForQuestionAnswering
>>> import torch
Example::
>>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
>>> model = XLNetForQuestionAnswering.from_pretrained('xlnet-base-cased')
>>> from transformers import XLNetTokenizer, XLNetForQuestionAnswering
>>> import torch
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])
>>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
>>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
>>> model = XLNetForQuestionAnswering.from_pretrained('xlnet-base-cased')
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])
>>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
>>> loss = outputs[0]
>>> loss = outputs[0]
"""
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
@ -1665,6 +1869,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_tuple=return_tuple,
)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
@ -1694,7 +1899,15 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
total_loss += cls_loss * 0.5
outputs = (total_loss,) + outputs
if return_tuple:
return (total_loss,) + transformer_outputs[1:]
else:
return XLNetForQuestionAnsweringOutput(
loss=total_loss,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
else:
# during inference, compute the end logits based on beam search
@ -1728,8 +1941,17 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
hidden_states, start_states=start_states, cls_index=cls_index
) # Shape (batch size,): one single `cls_logits` for each sample
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
# or (if labels are provided) (total_loss,)
return outputs
if return_tuple:
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
return outputs + transformer_outputs[1:]
else:
return XLNetForQuestionAnsweringOutput(
start_top_log_probs=start_top_log_probs,
start_top_index=start_top_index,
end_top_log_probs=end_top_log_probs,
end_top_index=end_top_index,
cls_logits=cls_logits,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

Просмотреть файл

@ -220,7 +220,6 @@ class ModelTesterMixin:
def test_torchscript(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
self._create_and_check_torchscript(config, inputs_dict)
def test_torchscript_output_attentions(self):
@ -230,7 +229,6 @@ class ModelTesterMixin:
def test_torchscript_output_hidden_state(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
self._create_and_check_torchscript(config, inputs_dict)

Просмотреть файл

@ -355,6 +355,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
import tempfile
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs[0].return_tuple = True
model = T5Model(config_and_inputs[0])
with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export(

Просмотреть файл

@ -319,7 +319,7 @@ class TFModelTesterMixin:
outputs_dict = model(input_ids)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test intetgration with other keras modules
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model

Просмотреть файл

@ -347,6 +347,7 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
XLMForQuestionAnswering,
XLMForSequenceClassification,
XLMForQuestionAnsweringSimple,
XLMForTokenClassification,
)
if is_torch_available()
else ()

Просмотреть файл

@ -35,6 +35,7 @@ if is_torch_available():
XLNetForSequenceClassification,
XLNetForTokenClassification,
XLNetForQuestionAnswering,
XLNetForQuestionAnsweringSimple,
)
from transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_LIST
@ -458,6 +459,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
XLNetForTokenClassification,
XLNetForSequenceClassification,
XLNetForQuestionAnswering,
XLNetForQuestionAnsweringSimple,
XLNetForMultipleChoice,
)
if is_torch_available()