Add SequenceClassification and MultipleChoice TF models to Electra (#6227)

* Add SequenceClassification and MultipleChoice TF models to Electra

* Apply style

* Add summary_proj_to_labels to Electra config

* Finally mirroring the PT version of these models

* Apply style

* Fix Electra test
This commit is contained in:
Julien Plu 2020-08-05 15:04:27 +02:00 committed by GitHub
Parent 376c02e9a9
Commit 33966811bd
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
4 changed files: 281 additions and 12 deletions
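For reference, a minimal usage sketch of the new sequence-classification head (the checkpoint name and label count below are illustrative, not part of this diff):

```python
import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForSequenceClassification

# Any ELECTRA discriminator checkpoint should work; "small" keeps the example light.
tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
model = TFElectraForSequenceClassification.from_pretrained(
    "google/electra-small-discriminator", num_labels=2
)

encoding = tokenizer("ELECTRA now has a TF sequence classification head.", return_tensors="tf")
labels = tf.constant([1])

# With labels, the returned tuple starts with the loss, followed by the logits.
outputs = model(
    encoding["input_ids"],
    attention_mask=encoding["attention_mask"],
    labels=labels,
)
loss, logits = outputs[:2]
```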

View file

@@ -537,8 +537,10 @@ if is_tf_available():
from .modeling_tf_electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForSequenceClassification,
TFElectraForTokenClassification,
TFElectraModel,
TFElectraPreTrainedModel,

View file

@@ -77,8 +77,10 @@ from .modeling_tf_distilbert import (
)
from .modeling_tf_electra import (
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
TFElectraForPreTraining,
TFElectraForQuestionAnswering,
TFElectraForSequenceClassification,
TFElectraForTokenClassification,
TFElectraModel,
)
@@ -247,6 +249,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(MobileBertConfig, TFMobileBertForSequenceClassification),
(FlaubertConfig, TFFlaubertForSequenceClassification),
(XLMConfig, TFXLMForSequenceClassification),
(ElectraConfig, TFElectraForSequenceClassification),
]
)
@@ -294,6 +297,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
(XLNetConfig, TFXLNetForMultipleChoice),
(FlaubertConfig, TFFlaubertForMultipleChoice),
(AlbertConfig, TFAlbertForMultipleChoice),
(ElectraConfig, TFElectraForMultipleChoice),
]
)
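With ElectraConfig added to these two mappings, the TF auto classes can resolve ELECTRA checkpoints to the new heads as well. A hedged sketch, assuming the corresponding TFAutoModel* classes are exported as the mapping tables above suggest (the checkpoint name is illustrative):

```python
from transformers import TFAutoModelForMultipleChoice, TFAutoModelForSequenceClassification

# ElectraConfig now maps to TFElectraForSequenceClassification / TFElectraForMultipleChoice.
seq_model = TFAutoModelForSequenceClassification.from_pretrained("google/electra-small-discriminator")
mc_model = TFAutoModelForMultipleChoice.from_pretrained("google/electra-small-discriminator")
```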

View file

@@ -4,11 +4,19 @@ import tensorflow as tf
from transformers import ElectraConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_callable,
)
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
from .modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFSequenceSummary,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
@@ -20,6 +28,7 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
_CONFIG_FOR_DOC = "ElectraConfig"
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"google/electra-small-generator",
@@ -73,13 +82,7 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
super().build(input_shape)
def call(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
self, input_ids, position_ids=None, token_type_ids=None, inputs_embeds=None, mode="embedding", training=False,
):
"""Get token embeddings of inputs.
Args:
@@ -438,7 +441,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def call(
self,
input_ids=None,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -539,7 +542,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-generator")
def call(
self,
input_ids=None,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -604,6 +607,225 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
return output # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
class TFElectraClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
def call(self, inputs, **kwargs):
x = inputs[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu here
x = self.dropout(x)
x = self.out_proj(x)
return x
@add_start_docstrings(
"""ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
ELECTRA_START_DOCSTRING,
)
class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.electra = TFElectraMainLayer(config, name="electra")
self.classifier = TFElectraClassificationHead(config, name="classifier")
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/electra-small-discriminator",
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`)
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
outputs = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
training=training,
)
logits = self.classifier(outputs[0])
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, logits)
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings(
"""ELECTRA Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
ELECTRA_START_DOCSTRING,
)
class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.electra = TFElectraMainLayer(config, name="electra")
self.sequence_summary = TFSequenceSummary(
config, initializer_range=config.initializer_range, name="sequence_summary"
)
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
def call(
self,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Return:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
labels = inputs[8] if len(inputs) > 8 else labels
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
labels = inputs.get("labels", labels)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
else None
)
outputs = self.electra(
flat_input_ids,
flat_attention_mask,
flat_token_type_ids,
flat_position_ids,
head_mask,
flat_inputs_embeds,
output_attentions,
output_hidden_states,
training=training,
)
logits = self.sequence_summary(outputs[0])
logits = self.classifier(logits)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss = self.compute_loss(labels, reshaped_logits)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings(
"""Electra model with a token classification head on top.
@@ -624,7 +846,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
def call(
self,
inputs=None,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
@@ -706,7 +928,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator")
def call(
self,
inputs=None,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,

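The new TFElectraForMultipleChoice expects inputs shaped (batch_size, num_choices, sequence_length): it flattens the choice dimension before the encoder, summarizes each sequence with TFSequenceSummary, and reshapes the per-choice scores back at the end. A rough sketch of a two-choice forward pass (checkpoint and example text are illustrative):

```python
import tensorflow as tf
from transformers import ElectraTokenizer, TFElectraForMultipleChoice

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
model = TFElectraForMultipleChoice.from_pretrained("google/electra-small-discriminator")

prompt = "The chef tasted the soup and"
choices = ["added more salt.", "painted the fence."]

# Encode one (prompt, choice) pair per candidate, padded to a common length.
encoding = tokenizer([prompt] * len(choices), choices, padding=True, return_tensors="tf")

# Add the num_choices dimension: (batch_size=1, num_choices=2, seq_len).
inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}

outputs = model(inputs)
logits = outputs[0]  # shape (1, 2): one score per choice
predicted_choice = int(tf.argmax(logits, axis=-1)[0])
```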
View file

@@ -24,10 +24,14 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers.modeling_tf_electra import (
TFElectraModel,
TFElectraForMaskedLM,
TFElectraForMultipleChoice,
TFElectraForPreTraining,
TFElectraForSequenceClassification,
TFElectraForTokenClassification,
TFElectraForQuestionAnswering,
)
@@ -138,6 +142,35 @@ class TFElectraModelTester:
}
self.parent.assertListEqual(list(result["prediction_scores"].shape), [self.batch_size, self.seq_length])
def create_and_check_electra_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFElectraForSequenceClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
(logits,) = model(inputs)
result = {
"logits": logits.numpy(),
}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
def create_and_check_electra_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = TFElectraForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids,
}
(logits,) = model(inputs)
result = {"logits": logits.numpy()}
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
def create_and_check_electra_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -210,6 +243,14 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
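To exercise only the new test cases locally, something like the following should work, assuming the usual tests/test_modeling_tf_electra.py location in the repository:

```python
# Equivalent to: pytest tests/test_modeling_tf_electra.py -k "sequence_classification or multiple_choice"
import pytest

pytest.main(
    ["tests/test_modeling_tf_electra.py", "-k", "sequence_classification or multiple_choice"]
)
```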