Refactor checkpoint name in BERT and MobileBERT (#10424)

* Refactor checkpoint name in BERT and MobileBERT

* Add option to check copies

* Add QuestionAnswering

* Add last models

* Make black happy
This commit is contained in:
Sylvain Gugger 2021-03-03 11:21:17 -05:00 коммит произвёл GitHub
Родитель 39f70a4058
Коммит 801ff969ce
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 43 добавлений и 18 удалений

Просмотреть файл

@ -58,6 +58,7 @@ from .configuration_bert import BertConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
@ -862,7 +863,7 @@ class BertModel(BertPreTrainedModel):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@ -1273,7 +1274,7 @@ class BertForMaskedLM(BertPreTrainedModel):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
@ -1468,7 +1469,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@ -1552,7 +1553,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@ -1647,7 +1648,7 @@ class BertForTokenClassification(BertPreTrainedModel):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@ -1737,7 +1738,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)

Просмотреть файл

@ -56,6 +56,7 @@ from .configuration_mobilebert import MobileBertConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "google/mobilebert-uncased"
_CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
@ -818,7 +819,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
)
@ -1033,7 +1034,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
@ -1204,20 +1205,22 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
""",
MOBILEBERT_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.mobilebert = MobileBertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.num_labels)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
@ -1253,7 +1256,9 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
@ -1286,6 +1291,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
""",
MOBILEBERT_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
@ -1302,7 +1308,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@ -1403,10 +1409,11 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
# Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.forward with Bert->MobileBert all-casing
def forward(
self,
input_ids=None,
@ -1481,6 +1488,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
""",
MOBILEBERT_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
class MobileBertForTokenClassification(MobileBertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
@ -1498,7 +1506,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="google/mobilebert-uncased",
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)

Просмотреть файл

@ -73,7 +73,7 @@ def find_code_in_transformers(object_name):
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"with\s+(\S+)->(\S+)(?:\s|$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
def blackify(code):
@ -93,6 +93,16 @@ def blackify(code):
return result[len("class Bla:\n") :] if has_indent else result
def get_indent(code):
lines = code.split("\n")
idx = 0
while idx < len(lines) and len(lines[idx]) == 0:
idx += 1
if idx < len(lines):
return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
return 0
def is_copy_consistent(filename, overwrite=False):
"""
Check if the code commented as a copy in `filename` matches the original.
@ -113,7 +123,7 @@ def is_copy_consistent(filename, overwrite=False):
# There is some copied code here, let's retrieve the original.
indent, object_name, replace_pattern = search.groups()
theoretical_code = find_code_in_transformers(object_name)
theoretical_indent = re.search(r"^(\s*)\S", theoretical_code).groups()[0]
theoretical_indent = get_indent(theoretical_code)
start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
indent = theoretical_indent
@ -138,10 +148,16 @@ def is_copy_consistent(filename, overwrite=False):
# Before comparing, use the `replace_pattern` on the original code.
if len(replace_pattern) > 0:
search_patterns = _re_replace_pattern.search(replace_pattern)
if search_patterns is not None:
obj1, obj2 = search_patterns.groups()
patterns = replace_pattern.replace("with", "").split(",")
patterns = [_re_replace_pattern.search(p) for p in patterns]
for pattern in patterns:
if pattern is None:
continue
obj1, obj2, option = pattern.groups()
theoretical_code = re.sub(obj1, obj2, theoretical_code)
if option.strip() == "all-casing":
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
# Test for a diff and act accordingly.
if observed_code != theoretical_code: