Bugfix: Removal of padding_idx in BartLearnedPositionalEmbedding (#10200)
* Assumption of padding_idx <2 might not stand * Use offset instead of 2 * Fix with black * Change behavior to warning instead for backward compatibility. * Fix with black * Remove warning * Make padding_idx non-required * padding_idx fix for blenderbot * padding_idx fix for blenderbot_small * padding_idx fix for led * padding_idx fix for mbart * Remove extra whitespaces * padding_idx fix for template * Fix padding_idx passed to nn.Embedding mistake * Fixed padding_idx passed to positional embedding in template * Remove padding_idx from pytorch learned positional embeddings * Remove accidentally added quotes * Remove padding_idx from tf learned positional embeddings * Remove zeroing of weights in __init__ Co-authored-by: Wang Ming Rui <mingrui.wang@C02CJTUYMD6M.local>
This commit is contained in:
Родитель
55fe80d084
Коммит
894db6701e
|
@ -108,12 +108,11 @@ class BartLearnedPositionalEmbedding(nn.Embedding):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
|
||||
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models dont have this hack
|
||||
self.offset = 2
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
|
@ -673,7 +672,6 @@ class BartEncoder(BartPretrainedModel):
|
|||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
@ -836,7 +834,6 @@ class BartDecoder(BartPretrainedModel):
|
|||
self.embed_positions = BartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
|
|
@ -113,8 +113,7 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
|
||||
assert padding_idx is not None, "padding_idx cannot be None"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
|
||||
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models dont have this hack
|
||||
self.offset = 2
|
||||
|
@ -632,7 +631,6 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFBartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
|
@ -793,7 +791,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFBartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||
|
|
|
@ -112,9 +112,8 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
|
||||
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
|
@ -635,7 +634,6 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([BlenderbotEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
@ -800,7 +798,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([BlenderbotDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
|
|
@ -118,8 +118,7 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
|
||||
assert padding_idx is not None, "padding_idx cannot be None"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
|
||||
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||
|
||||
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
||||
|
@ -629,7 +628,6 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFBlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFBlenderbotEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
|
@ -797,7 +795,6 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFBlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||
|
|
|
@ -110,9 +110,8 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
|
||||
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
|
@ -636,7 +635,6 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
|||
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([BlenderbotSmallEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
@ -800,7 +798,6 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||
self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([BlenderbotSmallDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
|
|
@ -117,8 +117,7 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
|
||||
assert padding_idx is not None, "padding_idx cannot be None"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
|
||||
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||
|
||||
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
||||
|
@ -634,7 +633,6 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFBlenderbotSmallEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
|
@ -802,7 +800,6 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFBlenderbotSmallLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||
|
|
|
@ -112,9 +112,8 @@ class LEDLearnedPositionalEmbedding(nn.Embedding):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
|
||||
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
|
@ -1622,7 +1621,6 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||
self.embed_positions = LEDLearnedPositionalEmbedding(
|
||||
self.max_source_positions,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([LEDEncoderLayer(config, i) for i in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
@ -1891,7 +1889,6 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||
self.embed_positions = LEDLearnedPositionalEmbedding(
|
||||
self.max_target_positions,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([LEDDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
|
|
@ -108,8 +108,7 @@ class TFLEDLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
|
||||
assert padding_idx is not None, "padding_idx cannot be None"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
|
||||
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||
|
||||
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
||||
|
@ -1612,7 +1611,6 @@ class TFLEDEncoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFLEDLearnedPositionalEmbedding(
|
||||
config.max_encoder_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
|
@ -1865,7 +1863,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFLEDLearnedPositionalEmbedding(
|
||||
config.max_decoder_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFLEDDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
|
||||
|
|
|
@ -114,12 +114,11 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
|
||||
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models dont have this hack
|
||||
self.offset = 2
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx)
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
|
@ -678,7 +677,6 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||
self.embed_positions = MBartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
@ -844,7 +842,6 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||
self.embed_positions = MBartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
|
|
@ -115,8 +115,7 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
|
||||
assert padding_idx is not None, "padding_idx cannot be None"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
|
||||
# MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models dont have this hack
|
||||
self.offset = 2
|
||||
|
@ -636,7 +635,6 @@ class TFMBartEncoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFMBartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFMBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
|
@ -806,7 +804,6 @@ class TFMBartDecoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TFMBartLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||
|
|
|
@ -777,7 +777,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.distilbert.modeling_tf_distilbert.TFDistilBertModel.serving_output
|
||||
def serving_output(self, output: TFBaseModelOutput) -> TFBaseModelOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
|
@ -800,7 +800,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
|||
|
||||
self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}")
|
||||
self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, input_embeddings=self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls")
|
||||
|
||||
|
||||
def get_lm_head(self) -> tf.keras.layers.Layer:
|
||||
return self.mlm.predictions
|
||||
|
||||
|
@ -876,7 +876,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
|||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
|
||||
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
|
@ -975,7 +975,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
|||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output
|
||||
def serving_output(self, output: TFCausalLMOutput) -> TFCausalLMOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
|
@ -1015,7 +1015,7 @@ class TF{{cookiecutter.camelcase_modelname}}ClassificationHead(tf.keras.layers.L
|
|||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""{{cookiecutter.modelname}} Model transformer with a sequence classification/regression head on top
|
||||
"""{{cookiecutter.modelname}} Model transformer with a sequence classification/regression head on top
|
||||
e.g., for GLUE tasks. """,
|
||||
{{cookiecutter.uppercase_modelname}}_START_DOCSTRING,
|
||||
)
|
||||
|
@ -1098,7 +1098,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
|||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
|
||||
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
|
@ -1239,7 +1239,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
|||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@tf.function(input_signature=[{
|
||||
"input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"),
|
||||
"attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"),
|
||||
|
@ -1250,7 +1250,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
|||
output = self.call(input_ids=inputs)
|
||||
|
||||
return self.serving_output(output)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
|
||||
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
|
@ -1347,7 +1347,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
|||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
|
||||
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
|
@ -1458,7 +1458,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
|||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
|
||||
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
|
@ -1565,8 +1565,7 @@ class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedE
|
|||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, **kwargs):
|
||||
assert padding_idx is not None, "padding_idx cannot be None"
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
|
||||
super().__init__(num_embeddings, embedding_dim, **kwargs)
|
||||
|
||||
def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
|
||||
|
@ -1876,7 +1875,7 @@ class TF{{cookiecutter.camelcase_modelname}}PreTrainedModel(TFPreTrainedModel):
|
|||
"input_ids": input_ids,
|
||||
}
|
||||
return dummy_inputs
|
||||
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
|
@ -2017,7 +2016,6 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
|
@ -2028,7 +2026,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
|||
|
||||
def set_embed_tokens(self, embed_tokens):
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
|
@ -2160,7 +2158,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||
self.embed_positions = TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
|
||||
|
@ -2328,7 +2325,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||
|
||||
if inputs["output_hidden_states"]:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
|
||||
if inputs["output_attentions"]:
|
||||
all_self_attns = list(all_self_attns)
|
||||
|
||||
|
@ -2344,7 +2341,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
|||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
@tf.function
|
||||
def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_length):
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
@ -2393,7 +2390,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||
|
||||
self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, embed_tokens, name="encoder")
|
||||
self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, embed_tokens, name="decoder")
|
||||
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.shared
|
||||
|
||||
|
@ -2407,7 +2404,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
|
||||
self.encoder.set_embed_tokens(embed_tokens)
|
||||
self.decoder.set_embed_tokens(embed_tokens)
|
||||
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
|
@ -2503,7 +2500,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
|||
class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||
def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
|
||||
self.model = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="model")
|
||||
|
||||
def get_encoder(self):
|
||||
|
@ -2572,7 +2569,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
|
@ -2580,7 +2577,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
|||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||
|
||||
|
||||
return TFSeq2SeqModelOutput(
|
||||
last_hidden_state=output.last_hidden_state,
|
||||
past_key_values=pkv,
|
||||
|
@ -2614,7 +2611,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
|
@ -2623,7 +2620,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||
|
||||
def set_bias(self, value):
|
||||
self.final_logits_bias = value["final_logits_bias"]
|
||||
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.get_input_embeddings()
|
||||
|
||||
|
@ -2725,7 +2722,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
|
||||
encoder_attentions=outputs.encoder_attentions, # 2 of e out
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
|
@ -2733,7 +2730,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
|
|||
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
|
||||
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
|
||||
enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
|
||||
|
||||
|
||||
return TFSeq2SeqLMOutput(
|
||||
logits=output.logits,
|
||||
past_key_values=pkv,
|
||||
|
|
|
@ -690,18 +690,18 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
|||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||
1]``:
|
||||
|
||||
|
||||
- 0 corresponds to a `sentence A` token,
|
||||
- 1 corresponds to a `sentence B` token.
|
||||
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
|
@ -710,10 +710,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
|
|||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
|
@ -1613,12 +1613,11 @@ def _expand_mask(
|
|||
|
||||
class {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
|
||||
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int):
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
|
||||
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
|
@ -2172,7 +2171,6 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||
self.embed_positions = {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
||||
|
@ -2258,7 +2256,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
if head_mask is not None:
|
||||
assert head_mask.size()[0] == (
|
||||
|
@ -2335,7 +2333,6 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||
self.embed_positions = {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}DecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
||||
|
|
Загрузка…
Ссылка в новой задаче