This commit is contained in:
Lysandre 2019-10-30 22:30:21 +00:00 коммит произвёл Lysandre Debut
Родитель 25a31953e8
Коммит 870320a24e
2 изменённых файлов: 219 добавлений и 31 удалений

Просмотреть файл

@ -11,6 +11,15 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'albert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
'albert-large': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
'albert-xlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
'albert-xxlarge': "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
}
def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model."""
try:
@ -39,6 +48,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
for name, array in zip(names, arrays):
original_name = name
name = name.replace("ffn_1", "ffn")
name = name.replace("/bert/", "/albert/")
name = name.replace("ffn/intermediate/output", "ffn_output")
name = name.replace("attention_1", "attention")
name = name.replace("cls/predictions", "predictions")
@ -114,29 +124,6 @@ class AlbertAttention(BertSelfAttention):
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.num_attention_heads, self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.query = prune_linear_layer(self.query, index)
self.key = prune_linear_layer(self.key, index)
self.value = prune_linear_layer(self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.num_attention_heads = self.num_attention_heads - len(heads)
self.all_head_size = self.attention_head_size * self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, input_ids, attention_mask=None, head_mask=None):
mixed_query_layer = self.query(input_ids)
mixed_key_layer = self.key(input_ids)
@ -225,7 +212,7 @@ class AlbertLayerGroup(nn.Module):
layer_attentions = layer_attentions + (layer_output[1],)
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
layer_hidden_states = layer_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
@ -367,6 +354,8 @@ class AlbertModel(BertModel):
self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
self.pooler_activation = nn.Tanh()
self.init_weights()
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
if attention_mask is None:
@ -422,33 +411,41 @@ class AlbertForMaskedLM(BertPreTrainedModel):
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
"""
config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
def __init__(self, config):
super(AlbertForMaskedLM, self).__init__(config)
self.config = config
self.bert = AlbertModel(config)
self.albert = AlbertModel(config)
self.LayerNorm = nn.LayerNorm(config.embedding_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
self.word_embeddings = nn.Linear(config.embedding_size, config.vocab_size)
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
self.activation = ACT2FN[config.hidden_act]
self.init_weights()
self.tie_weights()
def tie_weights(self):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.classifier.word_embeddings,
self.transformer.embeddings.word_embeddings)
self._tie_or_clone_weights(self.decoder,
self.albert.embeddings.word_embeddings)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None):
outputs = self.bert(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)
outputs = self.albert(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)
sequence_outputs = outputs[0]
hidden_states = self.dense(sequence_outputs)
hidden_states = self.activation(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
prediction_scores = self.word_embeddings(hidden_states)
prediction_scores = self.decoder(hidden_states)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
if masked_lm_labels is not None:

Просмотреть файл

@ -0,0 +1,191 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pytest
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
if is_torch_available():
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM)
from transformers.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
else:
pytestmark = pytest.mark.skip("Require Torch")
class AlbertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
class AlbertModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = AlbertConfig(
vocab_size_or_config_json_file=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_and_check_albert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = AlbertModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
sequence_output, pooled_output = model(input_ids)
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
}
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_albert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = AlbertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask,
sequence_labels, token_labels, choice_labels) = config_and_inputs
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
def setUp(self):
self.model_tester = AlbertModelTest.AlbertModelTester(self)
self.config_tester = ConfigTester(self, config_class=AlbertConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_albert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_albert_for_masked_lm(*config_and_inputs)
@pytest.mark.slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
if __name__ == "__main__":
unittest.main()