Templates overhaul (#8993)
This commit is contained in:
Parent
447808c85f
Commit
67ff1c314a
@@ -0,0 +1,65 @@
+name: Model templates runner
+
+on:
+  push:
+    paths:
+      - "src/**"
+      - "tests/**"
+      - ".github/**"
+      - "templates/**"
+
+jobs:
+  run_tests_templates:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v1
+
+      - name: Install Python
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.6
+
+      - name: Loading cache.
+        uses: actions/cache@v2
+        id: cache
+        with:
+          path: ~/.cache/pip
+          key: v1.2-tests_templates
+          restore-keys: |
+            v1.2-tests_templates-${{ hashFiles('setup.py') }}
+            v1.2-tests_templates
+
+      - name: Install dependencies
+        run: |
+          pip install --upgrade pip
+          pip install .[dev]
+      - name: Create model files
+        run: |
+          transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/encoder-bert-tokenizer.json --path=templates/adding_a_new_model
+          transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/pt-encoder-bert-tokenizer.json --path=templates/adding_a_new_model
+          transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/standalone.json --path=templates/adding_a_new_model
+          transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/tf-encoder-bert-tokenizer.json --path=templates/adding_a_new_model
+          make style
+          python utils/check_table.py --fix_and_overwrite
+          python utils/check_dummies.py --fix_and_overwrite
+
+      - name: Run all non-slow tests
+        run: |
+          python -m pytest -n 2 --dist=loadfile -s --make-reports=tests_templates tests/*template*
+
+      - name: Run style changes
+        run: |
+          git fetch origin master:master
+          make fixup
+
+      - name: Failure short reports
+        if: ${{ always() }}
+        run: cat reports/tests_templates_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: run_all_tests_templates_test_reports
+          path: reports

@@ -19,12 +19,18 @@ from argparse import ArgumentParser, Namespace
 from pathlib import Path
 from typing import List
 
-from cookiecutter.main import cookiecutter
 from transformers.commands import BaseTransformersCLICommand
 
 from ..utils import logging
 
 
+try:
+    from cookiecutter.main import cookiecutter
+
+    _has_cookiecutter = True
+except ImportError:
+    _has_cookiecutter = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -49,6 +55,11 @@ class AddNewModelCommand(BaseTransformersCLICommand):
         self._path = path
 
     def run(self):
+        if not _has_cookiecutter:
+            raise ImportError(
+                "Model creation dependencies are required to use the `add_new_model` command. Install them by running "
+                "the following at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
+            )
         # Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory
         directories = [directory for directory in os.listdir() if "cookiecutter-template-" == directory[:22]]
         if len(directories) > 0:

@@ -153,6 +164,11 @@ class AddNewModelCommand(BaseTransformersCLICommand):
             f"{model_dir}/tokenization_{lowercase_model_name}.py",
         )
 
+        shutil.move(
+            f"{directory}/tokenization_fast_{lowercase_model_name}.py",
+            f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
+        )
+
        from os import fdopen, remove
        from shutil import copymode, move
        from tempfile import mkstemp

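The `try`/`except` guard added above is the usual optional-dependency pattern: the import is attempted once at module load, and the failure is only surfaced when the command actually runs, so the rest of the CLI stays importable. A minimal, self-contained sketch of the same pattern (the command function here is illustrative, not part of the diff):

    try:
        from cookiecutter.main import cookiecutter  # noqa: F401

        _has_cookiecutter = True
    except ImportError:
        _has_cookiecutter = False


    def run_command():
        # Fail at call time, not at import time, so unrelated commands keep working.
        if not _has_cookiecutter:
            raise ImportError("Install the extra dependencies first: pip install -e .[modelcreation]")
        print("cookiecutter is available, generating templates...")
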
@@ -849,7 +849,7 @@ def add_code_sample_docstrings(
         elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
             doc_kwargs["mask"] = "[MASK]" if mask is None else mask
             code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
-        elif "LMHead" in model_class:
+        elif "LMHead" in model_class or "CausalLM" in model_class:
             code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
         elif "Model" in model_class or "Encoder" in model_class:
             code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE

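The change widens the dispatch so that the new `*ForCausalLM` classes also receive the causal-LM docstring sample. A standalone sketch of this name-based dispatch (the sample constants are stand-ins for the real `PT_*_SAMPLE` strings):

    PT_MASKED_LM_SAMPLE = "<masked-lm docstring sample>"
    PT_CAUSAL_LM_SAMPLE = "<causal-lm docstring sample>"
    PT_BASE_MODEL_SAMPLE = "<base-model docstring sample>"


    def pick_code_sample(model_class: str) -> str:
        # Mirrors the elif chain: a substring of the class name decides which sample is used.
        if "MaskedLM" in model_class:
            return PT_MASKED_LM_SAMPLE
        elif "LMHead" in model_class or "CausalLM" in model_class:
            return PT_CAUSAL_LM_SAMPLE
        elif "Model" in model_class or "Encoder" in model_class:
            return PT_BASE_MODEL_SAMPLE
        raise ValueError(f"Docstring can't be built for model {model_class}")


    assert pick_code_sample("BrandNewBertForCausalLM") == PT_CAUSAL_LM_SAMPLE
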
@@ -17,20 +17,24 @@
 # limitations under the License.
 
 {%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %}
-from ...file_utils import is_tf_available, is_torch_available
+from ...file_utils import is_tf_available, is_torch_available, is_tokenizers_available
 {%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %}
-from ...file_utils import is_torch_available
+from ...file_utils import is_torch_available, is_tokenizers_available
 {%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %}
-from ...file_utils import is_tf_available
+from ...file_utils import is_tf_available, is_tokenizers_available
 {% endif %}
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
 from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
 
+if is_tokenizers_available():
+    from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast
+
 {%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
 if is_torch_available():
     from .modeling_{{cookiecutter.lowercase_modelname}} import (
         {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
         {{cookiecutter.camelcase_modelname}}ForMaskedLM,
+        {{cookiecutter.camelcase_modelname}}ForCausalLM,
         {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
         {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
         {{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -46,6 +50,7 @@ if is_tf_available():
     from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
         TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
         TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
+        TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
         TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
         TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
         TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -17,6 +17,7 @@
 
 import tensorflow as tf
 
+from transformers.modeling_tf_outputs import TFCausalLMOutput
 from ...activations_tf import get_tf_activation
 from ...file_utils import (
     MULTIPLE_CHOICE_DUMMY_INPUTS,

@@ -40,6 +41,7 @@ from ...modeling_tf_utils import (
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
+    TFCausalLanguageModelingLoss,
     TFSequenceSummary,
     get_initializer,
     input_processing,

@@ -111,19 +113,23 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
         mode="embedding",
         training=False,
     ):
-        """Get token embeddings of inputs.
+        """
+        Get token embeddings of inputs.
+
         Args:
             inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
             mode: string, a valid value is one of "embedding" and "linear".
+
         Returns:
-            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
-                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
-                linear tensor, float32 with shape [batch_size, length, vocab_size].
+            outputs: If mode == "embedding", output embedding tensor, float32 with shape [batch_size, length,
+            embedding_size]; if mode == "linear", output linear tensor, float32 with shape [batch_size, length,
+            vocab_size].
+
         Raises:
             ValueError: if mode is not valid.
+
         Shared weights logic adapted from
-            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
         """
         if mode == "embedding":
             return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)

@@ -161,9 +167,12 @@ class TF{{cookiecutter.camelcase_modelname}}Embeddings(tf.keras.layers.Layer):
         return embeddings
 
     def _linear(self, inputs):
-        """Computes logits by running inputs through a linear layer.
+        """
+        Computes logits by running inputs through a linear layer.
+
         Args:
-            inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+            inputs: A float32 tensor with shape [batch_size, length, hidden_size].
+
         Returns:
             float32 tensor with shape [batch_size, length, vocab_size].
         """

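The two modes documented above share a single weight matrix: "embedding" gathers rows of the table, while "linear" multiplies hidden states by its transpose to produce vocabulary logits. A minimal NumPy sketch of that weight tying (shapes are illustrative, no TensorFlow required):

    import numpy as np

    vocab_size, hidden_size = 10, 4
    weights = np.random.randn(vocab_size, hidden_size).astype(np.float32)

    # mode == "embedding": ids -> vectors, shape [batch_size, length, hidden_size]
    input_ids = np.array([[1, 2, 3]])
    embeddings = weights[input_ids]

    # mode == "linear": vectors -> logits over the vocabulary, shape [batch_size, length, vocab_size]
    logits = embeddings @ weights.T

    assert embeddings.shape == (1, 3, hidden_size)
    assert logits.shape == (1, 3, vocab_size)
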
@@ -327,7 +336,6 @@ class TF{{cookiecutter.camelcase_modelname}}Output(tf.keras.layers.Layer):
         return hidden_states
 
 
-# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->{{cookiecutter.camelcase_modelname}}
 class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)

@@ -336,6 +344,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
         self.intermediate = TF{{cookiecutter.camelcase_modelname}}Intermediate(config, name="intermediate")
         self.{{cookiecutter.lowercase_modelname}}_output = TF{{cookiecutter.camelcase_modelname}}Output(config, name="output")
 
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer.call with bert->{{cookiecutter.lowercase_modelname}}
     def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
         attention_outputs = self.attention(
             hidden_states, attention_mask, head_mask, output_attentions, training=training

@@ -347,7 +356,7 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
 
         return outputs
 
-
+# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->{{cookiecutter.camelcase_modelname}}
 class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)

@@ -453,20 +462,6 @@ class TF{{cookiecutter.camelcase_modelname}}MLMHead(tf.keras.layers.Layer):
         return prediction_scores
 
 
-class TF{{cookiecutter.camelcase_modelname}}NSPHead(tf.keras.layers.Layer):
-    def __init__(self, config, **kwargs):
-        super().__init__(**kwargs)
-
-        self.seq_relationship = tf.keras.layers.Dense(
-            2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
-        )
-
-    def call(self, pooled_output):
-        seq_relationship_score = self.seq_relationship(pooled_output)
-
-        return seq_relationship_score
-
-
 @keras_serializable
 class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
     config_class = {{cookiecutter.camelcase_modelname}}Config

@@ -600,7 +595,6 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
         )
 
 
-# Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainedModel with Bert->{{cookiecutter.camelcase_modelname}}
 class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
     """An abstract class to handle weights initialization and
     a simple interface for downloading and loading pretrained models.

@@ -855,6 +849,97 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFMaskedLanguageModelingLoss):
             attentions=outputs.attentions,
         )
 
+
+@add_start_docstrings(
+    """{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
+)
+class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFCausalLanguageModelingLoss):
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `TF{{cookiecutter.camelcase_modelname}}ForCausalLM` as a standalone, add `is_decoder=True`.")
+
+        self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
+        self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
+
+    def get_output_embeddings(self):
+        return self.{{cookiecutter.lowercase_modelname}}.embeddings
+
+    @add_code_sample_docstrings(
+        tokenizer_class=_TOKENIZER_FOR_DOC,
+        checkpoint="{{cookiecutter.checkpoint_identifier}}",
+        output_type=TFCausalLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+        training=False,
+        **kwargs,
+    ):
+        r"""
+        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the cross entropy classification loss. Indices should be in ``[0, ...,
+            config.vocab_size - 1]``.
+        """
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            labels=labels,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        outputs = self.{{cookiecutter.lowercase_modelname}}(
+            inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            token_type_ids=inputs["token_type_ids"],
+            position_ids=inputs["position_ids"],
+            head_mask=inputs["head_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )
+        sequence_output = outputs[0]
+        logits = self.mlm(sequence_output, training=inputs["training"])
+        loss = None
+
+        if inputs["labels"] is not None:
+            # shift labels to the left and cut last logit token
+            logits = logits[:, :-1]
+            labels = inputs["labels"][:, 1:]
+            loss = self.compute_loss(labels, logits)
+
+        if not inputs["return_dict"]:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFCausalLMOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
 class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.Layer):
     """Head for sentence-level classification tasks."""
 
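Both the TF head above and the PyTorch head further down compute the causal loss the same way: the logits at position t are scored against the token at position t+1, so the last logit and the first label are dropped. A small sketch of that alignment (plain Python lists with illustrative token ids):

    # Each position's logits predict the NEXT token in the sequence.
    input_ids = [101, 2003, 2054, 102]

    # One (argmax) prediction per input position, as a model might produce.
    predicted_next = [2003, 2054, 102, 0]

    # The shift: logits[:, :-1] pairs with labels[:, 1:]
    shifted_predictions = predicted_next[:-1]  # drop the last logit: nothing follows it
    shifted_labels = input_ids[1:]             # drop the first token: nothing predicts it

    # Every remaining pair is (prediction for position t, true token at position t+1).
    assert list(zip(shifted_predictions, shifted_labels)) == [
        (2003, 2003),
        (2054, 2054),
        (102, 102),
    ]
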
@@ -1151,7 +1236,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFTokenClassificationLoss):
             training=training,
             kwargs_call=kwargs,
         )
-        outputs = self.{{cookiecutter.uppercase_modelname}}(
+        outputs = self.{{cookiecutter.lowercase_modelname}}(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
             token_type_ids=inputs["token_type_ids"],

@@ -1246,7 +1331,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TFQuestionAnsweringLoss):
             training=training,
             kwargs_call=kwargs,
         )
-        outputs = self.{{cookiecutter.uppercase_modelname}}(
+        outputs = self.{{cookiecutter.lowercase_modelname}}(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
             token_type_ids=inputs["token_type_ids"],

@@ -18,7 +18,6 @@
 
 import math
 import os
-import warnings
 
 import torch
 import torch.utils.checkpoint
@@ -29,10 +28,11 @@ from ...file_utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
 )
 from ...modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
+    BaseModelOutputWithCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
@@ -157,6 +157,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
 
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
 
     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None:

@@ -174,10 +175,12 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
-        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings

@@ -202,6 +205,10 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
 
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -236,6 +243,23 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in {{cookiecutter.camelcase_modelname}}Model forward() function)
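The block added above builds a [seq_length, seq_length] matrix of signed query-key offsets and shifts it into a non-negative range so it can index the `distance_embedding` table. A minimal PyTorch sketch of just that indexing (the sizes are made up for illustration):

    import torch

    seq_length, max_position_embeddings = 4, 8

    # Same construction as in the diff: pairwise signed distances between positions.
    position_ids_l = torch.arange(seq_length).view(-1, 1)
    position_ids_r = torch.arange(seq_length).view(1, -1)
    distance = position_ids_l - position_ids_r
    # distance[i, j] = i - j, ranging over [-(seq_length - 1), seq_length - 1]:
    # tensor([[ 0, -1, -2, -3],
    #         [ 1,  0, -1, -2],
    #         [ 2,  1,  0, -1],
    #         [ 3,  2,  1,  0]])

    # Shifting by max_position_embeddings - 1 maps distances onto valid embedding rows,
    # which is why the table has 2 * max_position_embeddings - 1 entries.
    indices = distance + max_position_embeddings - 1
    assert indices.min() >= 0 and indices.max() < 2 * max_position_embeddings - 1
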
@@ -432,10 +456,11 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
         encoder_attention_mask=None,
         output_attentions=False,
         output_hidden_states=False,
-        return_dict=False,
+        return_dict=True,
     ):
         all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
@@ -469,15 +494,24 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
             )
             hidden_states = layer_outputs[0]
             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
 
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
         )
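With the switch to `BaseModelOutputWithCrossAttentions`, the tuple path has to filter out entries that were not requested: the generator keeps only non-None values, so the tuple's length depends on the output flags. A toy illustration of that filtering behavior (plain Python stand-ins, no model involved):

    def encoder_outputs(output_hidden_states: bool, output_attentions: bool, return_dict: bool):
        hidden_states = "last_hidden_state"
        all_hidden_states = "hidden_states" if output_hidden_states else None
        all_self_attentions = "attentions" if output_attentions else None
        all_cross_attentions = "cross_attentions" if output_attentions else None

        if not return_dict:
            # Same pattern as in the diff: None entries silently disappear,
            # so callers must index the tuple positionally with care.
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )
        return {"last_hidden_state": hidden_states}  # stand-in for the ModelOutput dataclass


    assert encoder_outputs(False, False, return_dict=False) == ("last_hidden_state",)
    assert len(encoder_outputs(True, True, return_dict=False)) == 4
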
@@ -664,7 +698,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelname}}PreTrainedModel):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="{{cookiecutter.checkpoint_identifier}}",
-        output_type=BaseModelOutput,
+        output_type=BaseModelOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(

@@ -755,10 +789,11 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         if not return_dict:
             return (sequence_output,) + encoder_outputs[1:]
 
-        return BaseModelOutput(
+        return BaseModelOutputWithCrossAttentions(
             last_hidden_state=sequence_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
 
 
@@ -863,6 +898,127 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         return {"input_ids": input_ids, "attention_mask": attention_mask}
 
 
+@add_start_docstrings(
+    """{{cookiecutter.modelname}} Model with a `language modeling` head on top for CLM fine-tuning. """, {{cookiecutter.uppercase_modelname}}_START_DOCSTRING
+)
+class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
+
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `{{cookiecutter.camelcase_modelname}}ForCausalLM` as a standalone, add `is_decoder=True`.")
+
+        self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
+        self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100``
+            are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
+
+        Returns:
+
+        Example::
+
+            >>> from transformers import {{cookiecutter.camelcase_modelname}}Tokenizer, {{cookiecutter.camelcase_modelname}}ForCausalLM, {{cookiecutter.camelcase_modelname}}Config
+            >>> import torch
+
+            >>> tokenizer = {{cookiecutter.camelcase_modelname}}Tokenizer.from_pretrained('{{cookiecutter.checkpoint_identifier}}')
+            >>> config = {{cookiecutter.camelcase_modelname}}Config.from_pretrained("{{cookiecutter.checkpoint_identifier}}")
+            >>> config.is_decoder = True
+            >>> model = {{cookiecutter.camelcase_modelname}}ForCausalLM.from_pretrained('{{cookiecutter.checkpoint_identifier}}', config=config)
+
+            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+            >>> outputs = model(**inputs)
+
+            >>> prediction_logits = outputs.logits
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.{{cookiecutter.lowercase_modelname}}(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[1:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
 class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 

@@ -32,6 +32,7 @@ if is_tf_available():
         TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
         TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
         TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
+        TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
         TF{{cookiecutter.camelcase_modelname}}Model,
     )
 

@@ -134,6 +135,21 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
 
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
 
+    def create_and_check_lm_head(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.is_decoder = True
+        model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": input_mask,
+            "token_type_ids": token_type_ids,
+        }
+        prediction_scores = model(inputs)["logits"]
+        self.parent.assertListEqual(
+            list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
+        )
+
     def create_and_check_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):

@@ -224,6 +240,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unittest.TestCase):
     all_model_classes = (
         (
             TF{{cookiecutter.camelcase_modelname}}Model,
+            TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
             TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
             TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
             TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -249,6 +266,10 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
 
+    def test_for_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_lm_head(*config_and_inputs)
+
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

@@ -269,3 +290,31 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_model_from_pretrained(self):
         model = TF{{cookiecutter.camelcase_modelname}}Model.from_pretrained("{{cookiecutter.checkpoint_identifier}}")
         self.assertIsNotNone(model)
+
+
+@require_tf
+class TF{{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_masked_lm(self):
+        model = TF{{cookiecutter.camelcase_modelname}}ForMaskedLM.from_pretrained("{{cookiecutter.checkpoint_identifier}}")
+        input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
+        output = model(input_ids)[0]
+
+        # TODO Replace vocab size
+        vocab_size = 32000
+
+        expected_shape = [1, 6, vocab_size]
+        self.assertEqual(output.shape, expected_shape)
+
+        print(output[:, :3, :3])
+
+        # TODO Replace values below with what was printed above.
+        expected_slice = tf.constant(
+            [
+                [
+                    [-0.05243197, -0.04498899, 0.05512108],
+                    [-0.07444685, -0.01064632, 0.04352357],
+                    [-0.05020351, 0.05530146, 0.00700043],
+                ]
+            ]
+        )
+        tf.debugging.assert_near(output[:, :3, :3], expected_slice, atol=1e-4)

@@ -17,6 +17,7 @@
 
 import unittest
 
+from tests.test_modeling_common import floats_tensor
 from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 from .test_configuration_common import ConfigTester

@@ -25,9 +26,12 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
 
 
 if is_torch_available():
+    import torch
+
     from transformers import (
         {{cookiecutter.camelcase_modelname}}Config,
         {{cookiecutter.camelcase_modelname}}ForMaskedLM,
+        {{cookiecutter.camelcase_modelname}}ForCausalLM,
         {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
         {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
         {{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -122,6 +126,33 @@ class {{cookiecutter.camelcase_modelname}}ModelTester:
 
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 
+    def prepare_config_and_inputs_for_decoder(self):
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = self.prepare_config_and_inputs()
+
+        config.is_decoder = True
+        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+        return (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
     def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):

@@ -133,6 +164,56 @@ class {{cookiecutter.camelcase_modelname}}ModelTester:
         result = model(input_ids)
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
 
+    def create_and_check_model_as_decoder(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        config.add_cross_attention = True
+        model = {{cookiecutter.camelcase_modelname}}Model(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+        )
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+    def create_and_check_for_causal_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        model = {{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
     def create_and_check_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):

@@ -218,6 +299,7 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, unittest.TestCase):
         (
             {{cookiecutter.camelcase_modelname}}Model,
             {{cookiecutter.camelcase_modelname}}ForMaskedLM,
+            {{cookiecutter.camelcase_modelname}}ForCausalLM,
             {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
             {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
             {{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -226,6 +308,7 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    all_generative_model_classes = ({{cookiecutter.camelcase_modelname}}ForCausalLM,) if is_torch_available() else ()
 
     def setUp(self):
         self.model_tester = {{cookiecutter.camelcase_modelname}}ModelTester(self)

@@ -238,6 +321,12 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    def test_model_various_embeddings(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        for type in ["absolute", "relative_key", "relative_key_query"]:
+            config_and_inputs[0].position_embedding_type = type
+            self.model_tester.create_and_check_model(*config_and_inputs)
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

@@ -258,6 +347,38 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
 
+    def test_model_as_decoder(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
+
+    def test_model_as_decoder_with_default_input_mask(self):
+        # This regression test was failing with PyTorch < 1.3
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        ) = self.model_tester.prepare_config_and_inputs_for_decoder()
+
+        input_mask = None
+
+        self.model_tester.create_and_check_model_as_decoder(
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:

@@ -265,3 +386,23 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, unittest.TestCase):
             self.assertIsNotNone(model)
 
 
+@require_torch
+class {{cookiecutter.camelcase_modelname}}ModelIntegrationTest(unittest.TestCase):
+    @slow
+    def test_inference_masked_lm(self):
+        model = {{cookiecutter.camelcase_modelname}}ForMaskedLM.from_pretrained("{{cookiecutter.checkpoint_identifier}}")
+        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
+        output = model(input_ids)[0]
+
+        # TODO Replace vocab size
+        vocab_size = 32000
+
+        expected_shape = torch.Size((1, 6, vocab_size))
+        self.assertEqual(output.shape, expected_shape)
+
+        # TODO Replace values below with what was printed above.
+        expected_slice = torch.tensor(
+            [[[-0.0483, 0.1188, -0.0313], [-0.0606, 0.1435, 0.0199], [-0.0235, 0.1519, 0.0175]]]
+        )
+
+        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

@@ -31,6 +31,7 @@
 from .models.{{cookiecutter.lowercase_modelname}} import (
     {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
     {{cookiecutter.camelcase_modelname}}ForMaskedLM,
+    {{cookiecutter.camelcase_modelname}}ForCausalLM,
     {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
     {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
     {{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -47,6 +48,7 @@
 from .models.{{cookiecutter.lowercase_modelname}} import (
     TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
     TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
+    TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
     TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
     TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
     TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -57,10 +59,13 @@
 )
 # End.
 
+# Below: "if is_tokenizers_available():"
+# Replace with:
+from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
+
 # Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
 # Replace with:
-from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
+from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer
 # End.
 
 

@@ -96,9 +101,9 @@ from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
 
 # Below: "# Add modeling imports here"
 # Replace with:
-
 from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
     {{cookiecutter.camelcase_modelname}}ForMaskedLM,
+    {{cookiecutter.camelcase_modelname}}ForCausalLM,
     {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
     {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
     {{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -117,6 +122,11 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
     ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
 # End.
 
+# Below: "# Model for Causal LM mapping"
+# Replace with:
+({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
+# End.
+
 # Below: "# Model for Masked LM mapping"
 # Replace with:
 ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),

@@ -151,9 +161,9 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
 
 # Below: "# Add modeling imports here"
 # Replace with:
-
 from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
     TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
+    TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
     TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
     TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
     TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,

@@ -172,6 +182,11 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
     ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
 # End.
 
+# Below: "# Model for Causal LM mapping"
+# Replace with:
+({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForCausalLM),
+# End.
+
 # Below: "# Model for Masked LM mapping"
 # Replace with:
 ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),

@@ -0,0 +1,150 @@
+# coding=utf-8
+# Copyright {{cookiecutter.authors}} and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for {{cookiecutter.modelname}}."""
+
+{%- if cookiecutter.tokenizer_type == "Based on BERT" %}
+from ...utils import logging
+from ..bert.tokenization_bert_fast import BertTokenizerFast
+from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "{{cookiecutter.checkpoint_identifier}}": "https://huggingface.co/{{cookiecutter.checkpoint_identifier}}/resolve/main/vocab.txt",
+    }
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "{{cookiecutter.checkpoint_identifier}}": 512,
+}
+
+
+PRETRAINED_INIT_CONFIGURATION = {
+    "{{cookiecutter.checkpoint_identifier}}": {"do_lower_case": False},
+}
+
+
+class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast):
+    r"""
+    Construct a "fast" {{cookiecutter.modelname}} tokenizer (backed by HuggingFace's `tokenizers` library).
+
+    :class:`~transformers.{{cookiecutter.camelcase_modelname}}TokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs
+    end-to-end tokenization: punctuation splitting and wordpiece.
+
+    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
+    parameters.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+    slow_tokenizer_class = {{cookiecutter.camelcase_modelname}}Tokenizer
+{%- elif cookiecutter.tokenizer_type == "Standalone" %}
+from typing import List, Optional
+
+from tokenizers import ByteLevelBPETokenizer
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {}
+
+PRETRAINED_VOCAB_FILES_MAP = {}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "{{cookiecutter.checkpoint_identifier}}": 1024,
+}
+
+
+class {{cookiecutter.camelcase_modelname}}TokenizerFast(PreTrainedTokenizerFast):
+    """
+    Construct a "fast" {{cookiecutter.modelname}} tokenizer (backed by HuggingFace's `tokenizers` library).
+
+    Args:
+        vocab_file (:obj:`str`):
+            Path to the vocabulary file.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    slow_tokenizer_class = {{cookiecutter.camelcase_modelname}}Tokenizer
+
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        unk_token="<|endoftext|>",
+        bos_token="<|endoftext|>",
+        eos_token="<|endoftext|>",
+        add_prefix_space=False,
+        trim_offsets=True,
+        **kwargs
+    ):
+        super().__init__(
+            ByteLevelBPETokenizer(
+                vocab_file=vocab_file,
+                merges_file=merges_file,
+                add_prefix_space=add_prefix_space,
+                trim_offsets=trim_offsets,
+            ),
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            **kwargs,
+        )
+        self.add_prefix_space = add_prefix_space
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
+        if token_ids_1 is None:
+            return output
+
+        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task.
+        {{cookiecutter.modelname}} does not make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+{% endif %}
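For the Standalone branch above, the special-token layout and the all-zeros token type ids can be checked by hand. A small sketch with made-up special-token ids (all set to 0 here for simplicity; real values come from the trained tokenizer's vocabulary):

    bos_token_id = eos_token_id = 0
    sep_token_id = cls_token_id = 0

    def build_inputs_with_special_tokens(token_ids_0, token_ids_1=None):
        output = [bos_token_id] + token_ids_0 + [eos_token_id]
        if token_ids_1 is None:
            return output
        return output + [eos_token_id] + token_ids_1 + [eos_token_id]

    def create_token_type_ids_from_sequences(token_ids_0, token_ids_1=None):
        sep, cls = [sep_token_id], [cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    # Single sequence: <bos> tokens <eos>; the token type ids are all zero.
    assert build_inputs_with_special_tokens([5, 6]) == [0, 5, 6, 0]
    assert create_token_type_ids_from_sequences([5, 6]) == [0, 0, 0, 0]
    # Pair: <bos> A <eos> <eos> B <eos>, still all-zero type ids.
    assert build_inputs_with_special_tokens([5, 6], [7]) == [0, 5, 6, 0, 0, 7, 0]
    assert create_token_type_ids_from_sequences([5, 6], [7]) == [0] * 7
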
@@ -17,7 +17,6 @@
 {%- if cookiecutter.tokenizer_type == "Based on BERT" %}
 from ...utils import logging
 from ..bert.tokenization_bert import BertTokenizer
-from ..bert.tokenization_bert_fast import BertTokenizerFast
 
 
 logger = logging.get_logger(__name__)

@@ -55,30 +54,12 @@ class {{cookiecutter.camelcase_modelname}}Tokenizer(BertTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
-
-
-class {{cookiecutter.camelcase_modelname}}TokenizerFast(BertTokenizerFast):
-    r"""
-    Construct a "fast" {{cookiecutter.modelname}} tokenizer (backed by HuggingFace's `tokenizers` library).
-
-    :class:`~transformers.{{cookiecutter.camelcase_modelname}}TokenizerFast` is identical to :class:`~transformers.BertTokenizerFast` and runs
-    end-to-end tokenization: punctuation splitting and wordpiece.
-
-    Refer to superclass :class:`~transformers.BertTokenizerFast` for usage examples and documentation concerning
-    parameters.
-    """
-
-    vocab_files_names = VOCAB_FILES_NAMES
-    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
-    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
 {%- elif cookiecutter.tokenizer_type == "Standalone" %}
 from typing import List, Optional
 
 from tokenizers import ByteLevelBPETokenizer
 
 from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...tokenization_utils_base import BatchEncoding
-from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
 
 
@@ -58,6 +58,13 @@ Tips:
     :members: forward
 
 
+{{cookiecutter.camelcase_modelname}}ForCausalLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.{{cookiecutter.camelcase_modelname}}ForCausalLM
+    :members: forward
+
+
 {{cookiecutter.camelcase_modelname}}ForMaskedLM
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

@@ -109,6 +116,13 @@ TF{{cookiecutter.camelcase_modelname}}ForMaskedLM
     :members: call
 
 
+TF{{cookiecutter.camelcase_modelname}}ForCausalLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TF{{cookiecutter.camelcase_modelname}}ForCausalLM
+    :members: call
+
+
 TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 

@@ -1,8 +1,8 @@
 {
-    "modelname": "EncoderBERT",
-    "uppercase_modelname": "ENCODER_BERT",
-    "lowercase_modelname": "encoder_bert",
-    "camelcase_modelname": "EncoderBert",
+    "modelname": "Template",
+    "uppercase_modelname": "TEMPLATE",
+    "lowercase_modelname": "template",
+    "camelcase_modelname": "Template",
     "authors": "The HuggingFace Team",
     "checkpoint_identifier": "brand-new-bert-base-cased",
     "tokenizer_type": "Based on BERT",

@@ -1,8 +1,8 @@
 {
-    "modelname": "PTEncoderBERT",
-    "uppercase_modelname": "PT_ENCODER_BERT",
-    "lowercase_modelname": "pt_encoder_bert",
-    "camelcase_modelname": "PtEncoderBert",
+    "modelname": "TemplatePT",
+    "uppercase_modelname": "TEMPLATE_PT",
+    "lowercase_modelname": "template_pt",
+    "camelcase_modelname": "TemplatePt",
     "authors": "The HuggingFace Team",
     "checkpoint_identifier": "brand-new-bert-base-cased",
     "tokenizer_type": "Based on BERT",

@@ -1,8 +1,8 @@
 {
-    "modelname": "BIEncoderBERT",
-    "uppercase_modelname": "BI_ENCODER_BERT",
-    "lowercase_modelname": "bi_encoder_bert",
-    "camelcase_modelname": "BiEncoderBert",
+    "modelname": "TemplateBI",
+    "uppercase_modelname": "TEMPLATE_BI",
+    "lowercase_modelname": "template_bi",
+    "camelcase_modelname": "TemplateBi",
     "authors": "The HuggingFace Team",
     "checkpoint_identifier": "bi-brand-new-bert-base-cased",
     "tokenizer_type": "Standalone",

@@ -1,8 +1,8 @@
 {
-    "modelname": "TFEncoderBERT",
-    "uppercase_modelname": "TF_ENCODER_BERT",
-    "lowercase_modelname": "tf_encoder_bert",
-    "camelcase_modelname": "TfEncoderBert",
+    "modelname": "TemplateTF",
+    "uppercase_modelname": "TEMPLATE_TF",
+    "lowercase_modelname": "template_tf",
+    "camelcase_modelname": "TemplateTf",
     "authors": "The HuggingFace Team",
     "checkpoint_identifier": "brand-new-bert-base-cased",
     "tokenizer_type": "Based on BERT",