diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 018175707..aaf4c55db 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -58,6 +58,7 @@ from .configuration_bert import BertConfig
 
 logger = logging.get_logger(__name__)
 
+_CHECKPOINT_FOR_DOC = "bert-base-uncased"
 _CONFIG_FOR_DOC = "BertConfig"
 _TOKENIZER_FOR_DOC = "BertTokenizer"
 
@@ -862,7 +863,7 @@ class BertModel(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="bert-base-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1273,7 +1274,7 @@ class BertForMaskedLM(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="bert-base-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=MaskedLMOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1468,7 +1469,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="bert-base-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1552,7 +1553,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="bert-base-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=MultipleChoiceModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1647,7 +1648,7 @@ class BertForTokenClassification(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="bert-base-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1737,7 +1738,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="bert-base-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=QuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py
index c2a217862..becfc29c1 100644
--- a/src/transformers/models/mobilebert/modeling_mobilebert.py
+++ b/src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -56,6 +56,7 @@ from .configuration_mobilebert import MobileBertConfig
 
 logger = logging.get_logger(__name__)
 
+_CHECKPOINT_FOR_DOC = "google/mobilebert-uncased"
 _CONFIG_FOR_DOC = "MobileBertConfig"
 _TOKENIZER_FOR_DOC = "MobileBertTokenizer"
 
@@ -818,7 +819,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="google/mobilebert-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=BaseModelOutputWithPooling,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1033,7 +1034,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="google/mobilebert-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=MaskedLMOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1204,20 +1205,22 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
     """,
     MOBILEBERT_START_DOCSTRING,
 )
+# Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
 class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+
         self.mobilebert = MobileBertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
         self.init_weights()
 
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="google/mobilebert-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1253,7 +1256,9 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
+
         pooled_output = outputs[1]
+
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
 
@@ -1286,6 +1291,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
     """,
     MOBILEBERT_START_DOCSTRING,
 )
+# Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
 class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
 
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
@@ -1302,7 +1308,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="google/mobilebert-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=QuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1403,10 +1409,11 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
     )
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="google/mobilebert-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=MultipleChoiceModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
+    # Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.forward with Bert->MobileBert all-casing
     def forward(
         self,
         input_ids=None,
@@ -1481,6 +1488,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
     """,
     MOBILEBERT_START_DOCSTRING,
 )
+# Copied from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
 class MobileBertForTokenClassification(MobileBertPreTrainedModel):
 
     _keys_to_ignore_on_load_unexpected = [r"pooler"]
@@ -1498,7 +1506,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="google/mobilebert-uncased",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
     )
diff --git a/utils/check_copies.py b/utils/check_copies.py
index eabd10cc9..2f6538432 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -73,7 +73,7 @@ def find_code_in_transformers(object_name):
 
 
 _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
-_re_replace_pattern = re.compile(r"with\s+(\S+)->(\S+)(?:\s|$)")
+_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
 
 
 def blackify(code):
@@ -93,6 +93,16 @@ def blackify(code):
     return result[len("class Bla:\n") :] if has_indent else result
 
 
+def get_indent(code):
+    lines = code.split("\n")
+    idx = 0
+    while idx < len(lines) and len(lines[idx]) == 0:
+        idx += 1
+    if idx < len(lines):
+        return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
+    return 0
+
+
 def is_copy_consistent(filename, overwrite=False):
     """
     Check if the code commented as a copy in `filename` matches the original.
@@ -113,7 +123,7 @@ def is_copy_consistent(filename, overwrite=False):
             # There is some copied code here, let's retrieve the original.
             indent, object_name, replace_pattern = search.groups()
             theoretical_code = find_code_in_transformers(object_name)
-            theoretical_indent = re.search(r"^(\s*)\S", theoretical_code).groups()[0]
+            theoretical_indent = get_indent(theoretical_code)
 
             start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
             indent = theoretical_indent
@@ -138,10 +148,16 @@ def is_copy_consistent(filename, overwrite=False):
 
             # Before comparing, use the `replace_pattern` on the original code.
             if len(replace_pattern) > 0:
-                search_patterns = _re_replace_pattern.search(replace_pattern)
-                if search_patterns is not None:
-                    obj1, obj2 = search_patterns.groups()
+                patterns = replace_pattern.replace("with", "").split(",")
+                patterns = [_re_replace_pattern.search(p) for p in patterns]
+                for pattern in patterns:
+                    if pattern is None:
+                        continue
+                    obj1, obj2, option = pattern.groups()
                     theoretical_code = re.sub(obj1, obj2, theoretical_code)
+                    if option.strip() == "all-casing":
+                        theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
+                        theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
 
             # Test for a diff and act accordingly.
             if observed_code != theoretical_code:
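
For illustration, here is a minimal standalone sketch of how the reworked replacement logic in `utils/check_copies.py` handles the new `all-casing` option. The two regexes are copied from the hunks above; `comment` and `theoretical_code` are made-up inputs standing in for a real `# Copied from` line and for the code that `find_code_in_transformers` would return.

```python
import re

# Regexes as they appear in utils/check_copies.py after this change
# (duplicated here so the snippet runs on its own).
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")

# Hypothetical "Copied from" comment using the new all-casing option.
comment = (
    "# Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification "
    "with Bert->MobileBert all-casing"
)
_, object_name, replace_pattern = _re_copy_warning.search(comment).groups()

# Stand-in for the original Bert code fetched by find_code_in_transformers.
theoretical_code = (
    "self.bert = BertModel(config)\n"
    'BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")'
)

# Apply each "obj1->obj2 [all-casing]" pattern, mirroring the new loop in is_copy_consistent.
for pattern in [_re_replace_pattern.search(p) for p in replace_pattern.replace("with", "").split(",")]:
    if pattern is None:
        continue
    obj1, obj2, option = pattern.groups()
    theoretical_code = re.sub(obj1, obj2, theoretical_code)
    if option.strip() == "all-casing":
        # Also replace the lower- and upper-cased variants of the pattern.
        theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
        theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

print(theoretical_code)
# self.mobilebert = MobileBertModel(config)
# MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
```

The extra lower/upper substitutions are what let `self.bert` and `BERT_INPUTS_DOCSTRING` in the copied Bert code map onto `self.mobilebert` and `MOBILEBERT_INPUTS_DOCSTRING`, so whole MobileBert classes can now be marked as copies of their Bert counterparts.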