Add DistilBertForMultipleChoice (#5032)

* Add `DistilBertForMultipleChoice`
This commit is contained in:
Sylvain Gugger 2020-06-15 18:31:41 -04:00 коммит произвёл GitHub
Родитель 36434220fc
Коммит f9f8a5312e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 144 добавлений и 1 удалений

Просмотреть файл

@ -75,6 +75,13 @@ DistilBertForSequenceClassification
:members:
DistilBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.DistilBertForMultipleChoice
:members:
DistilBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Просмотреть файл

@ -271,6 +271,7 @@ if is_torch_available():
DistilBertPreTrainedModel,
DistilBertForMaskedLM,
DistilBertModel,
DistilBertForMultipleChoice,
DistilBertForSequenceClassification,
DistilBertForQuestionAnswering,
DistilBertForTokenClassification,

Просмотреть файл

@ -75,7 +75,7 @@ class DistilBertConfig(PretrainedConfig):
The dropout probabilities used in the question answering model
:class:`~transformers.DistilBertForQuestionAnswering`.
seq_classif_dropout (:obj:`float`, optional, defaults to 0.2):
The dropout probabilities used in the sequence classification model
The dropout probabilities used in the sequence classification and the multiple choice model
:class:`~transformers.DistilBertForSequenceClassification`.
Example::

Просмотреть файл

@ -78,6 +78,7 @@ from .modeling_camembert import (
from .modeling_ctrl import CTRLLMHeadModel, CTRLModel
from .modeling_distilbert import (
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
@ -314,6 +315,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
(LongformerConfig, LongformerForMultipleChoice),
(RobertaConfig, RobertaForMultipleChoice),
(BertConfig, BertForMultipleChoice),
(DistilBertConfig, DistilBertForMultipleChoice),
(XLNetConfig, XLNetForMultipleChoice),
(AlbertConfig, AlbertForMultipleChoice),
]

Просмотреть файл

@ -864,3 +864,111 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
@add_start_docstrings(
"""DistilBert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
DISTILBERT_START_DOCSTRING,
)
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.distilbert = DistilBertModel(config)
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.classifier = nn.Linear(config.dim, 1)
self.dropout = nn.Dropout(config.seq_classif_dropout)
self.init_weights()
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import DistilBertTokenizer, DistilBertForMultipleChoice
import torch
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased')
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
choice0 = "It is eaten with a fork and a knife."
choice1 = "It is eaten while held in the hand."
labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1
# the linear classifier still needs to be trained
loss, logits = outputs[:2]
"""
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
outputs = self.distilbert(
input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
)
hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs * num_choices, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim)
pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim)
pooled_output = self.dropout(pooled_output) # (bs * num_choices, dim)
logits = self.classifier(pooled_output) # (bs * num_choices, 1)
reshaped_logits = logits.view(-1, num_choices) # (bs, num_choices)
outputs = (reshaped_logits,) + outputs[1:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)

Просмотреть файл

@ -28,6 +28,7 @@ if is_torch_available():
DistilBertConfig,
DistilBertModel,
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForTokenClassification,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
@ -41,6 +42,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
(
DistilBertModel,
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
@ -218,6 +220,25 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
)
self.check_loss_output(result)
def create_and_check_distilbert_for_multiple_choice(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = DistilBertForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
@ -251,6 +272,10 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
# @slow
# def test_model_from_pretrained(self):
# for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: