tf add resize_token_embeddings method (#4351)
* resize token embeddings
* add tokens
* add tokens
* add tokens
* add t5 token method
* add t5 token method
* add t5 token method
* typo
* debugging input
* debugging input
* debug
* debug
* debug
* trying to set embedding tokens properly
* set embeddings for generation head too
* set embeddings for generation head too
* debugging
* debugging
* enable generation
* add base method
* add base method
* add base method
* return logits in the main call
* reverting to generation
* revert back
* set embeddings for the bert main layer
* description
* fix conflicts
* logging
* set base model as self
* refactor
* tf_bert add method
* tf_bert add method
* tf_bert add method
* tf_bert add method
* tf_bert add method
* tf_bert add method
* tf_bert add method
* tf_bert add method
* v0
* v0
* finalize
* final
* black
* add tests
* revert back the emb call
* comments
* comments
* add the second test
* add vocab size condig
* add tf models
* add tf models. add common tests
* remove model specific embedding tests
* stylish
* remove files
* stylez
* Update src/transformers/modeling_tf_transfo_xl.py change the error. Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* adding unchanged weight test

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Parent: 973433260e
Commit: 32e94cff64
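For context, a minimal usage sketch of the API this change adds for TF models. The checkpoint name, model class, and added tokens below are illustrative examples, not part of the diff:

    from transformers import BertTokenizer, TFBertModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertModel.from_pretrained("bert-base-uncased")

    # Grow the tokenizer, then grow the model's input embedding matrix to match.
    tokenizer.add_tokens(["[NEW_TOK_1]", "[NEW_TOK_2]"])
    model.resize_token_embeddings(len(tokenizer))

    # The resize keeps the config in sync (see the TFPreTrainedModel hunks below).
    assert model.config.vocab_size == len(tokenizer)
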
@@ -60,6 +60,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
        super().__init__(**kwargs)

        self.config = config
        self.vocab_size = config.vocab_size
        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.embedding_size,

@@ -515,6 +516,10 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

@@ -497,6 +497,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.num_hidden_layers = config.num_hidden_layers
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions

        self.embeddings = TFBertEmbeddings(config, name="embeddings")

@@ -506,8 +507,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.embeddings

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.

@@ -213,6 +213,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.w

    def set_input_embeddings(self, value):
        self.w.weight = value
        self.w.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

@@ -422,8 +422,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.embeddings

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError
    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

@@ -217,6 +217,10 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

@@ -371,9 +375,6 @@ class TFElectraModel(TFElectraPreTrainedModel):
        super().__init__(config, *inputs, **kwargs)
        self.electra = TFElectraMainLayer(config, name="electra")

    def get_input_embeddings(self):
        return self.electra.embeddings

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""

@@ -422,9 +423,6 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
        self.electra = TFElectraMainLayer(config, name="electra")
        self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")

    def get_input_embeddings(self):
        return self.electra.embeddings

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
    def call(
        self,

@@ -519,9 +517,6 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
        self.activation = config.hidden_act
        self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

    def get_input_embeddings(self):
        return self.electra.embeddings

    def get_output_embeddings(self):
        return self.generator_lm_head

@@ -235,8 +235,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.wte

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError
    def set_input_embeddings(self, value):
        self.wte.weight = value
        self.wte.vocab_size = self.wte.weight.shape[0]

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.

@@ -227,8 +227,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.tokens_embed

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError
    def set_input_embeddings(self, value):
        self.tokens_embed.weight = value
        self.tokens_embed.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.

@@ -101,9 +101,6 @@ class TFRobertaMainLayer(TFBertMainLayer):
        super().__init__(config, **kwargs)
        self.embeddings = TFRobertaEmbeddings(config, name="embeddings")

    def get_input_embeddings(self):
        return self.embeddings


class TFRobertaPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and

@@ -884,6 +884,16 @@ class TFT5Model(TFT5PreTrainedModel):
    def get_output_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared.weight = new_embeddings
        self.shared.vocab_size = self.shared.weight.shape[0]
        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
        self.encoder.set_embed_tokens(embed_tokens)
        self.decoder.set_embed_tokens(embed_tokens)

    def get_encoder(self):
        return self.encoder

@@ -1011,6 +1021,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
    def get_output_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared.weight = new_embeddings
        # retrieve correct absolute scope for embed token wrapper
        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
            pass
        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
        self.encoder.set_embed_tokens(embed_tokens)
        self.decoder.set_embed_tokens(embed_tokens)

    def get_encoder(self):
        return self.encoder

@@ -468,6 +468,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.word_emb

    def set_input_embeddings(self, value):
        raise NotImplementedError

    def _resize_token_embeddings(self, new_num_tokens):
        return self.word_emb

@@ -199,6 +199,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings

        Args:
            value (:obj:`tf.keras.layers.Layer`):
                A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

@@ -209,40 +223,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        """
        return None  # Overwrite for models with output embeddings

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``tf.Variable``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        # if new_num_tokens is None:
        # return old_embeddings

        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        # if old_num_tokens == new_num_tokens:
        # return old_embeddings

        # # Build new embeddings
        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        # new_embeddings.to(old_embeddings.weight.device)

        # # initialize all new embeddings (in particular added tokens)
        # self._init_weights(new_embeddings)

        # # Copy token embeddings from the previous weights
        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]

        # return new_embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

@@ -256,7 +236,71 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
        Return: ``tf.Variable``
            Pointer to the input tokens Embeddings Module of the model
        """
        raise NotImplementedError
        model_embeds = self._resize_token_embeddings(new_num_tokens)
        if new_num_tokens is None:
            return model_embeds

        return model_embeds

    def _resize_token_embeddings(self, new_num_tokens):
        # get_input_embeddings and set_input_embeddings need to be implemented in base layer.
        base_model = getattr(self, self.base_model_prefix, self)
        old_embeddings = base_model.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        base_model.set_input_embeddings(new_embeddings)
        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens
        return base_model.get_input_embeddings()

    def _get_word_embeddings(self, embeddings):
        if hasattr(embeddings, "word_embeddings"):
            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
            return embeddings.word_embeddings
        elif hasattr(embeddings, "weight"):
            # TFSharedEmbeddings
            return embeddings.weight
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end.

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.
        Return: ``tf.Variable``
            Pointer to the resized word Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
            return word_embeddings
        old_num_tokens, old_embedding_dim = word_embeddings.shape
        if old_num_tokens == new_num_tokens:
            return word_embeddings

        # initialize new embeddings
        # todo: initializer range is not always passed in config.
        init_range = getattr(self.config, "initializer_range", 0.02)
        new_embeddings = self.add_weight(
            "weight",
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_weights = new_embeddings.numpy()

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
        new_embeddings.assign(init_weights)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

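The resize itself amounts to creating a fresh variable of the target size and copying over the rows both matrices share, since a tf.Variable cannot change shape in place. A standalone sketch of that pattern with toy shapes (plain TensorFlow, not the library code above):

    import numpy as np
    import tensorflow as tf

    old_embeddings = tf.Variable(tf.random.normal([100, 32]))  # (old_vocab_size, hidden_size)
    new_num_tokens = 110
    old_num_tokens, embedding_dim = old_embeddings.shape

    # Fresh, randomly initialized variable with the new vocabulary size.
    new_embeddings = tf.Variable(tf.random.truncated_normal([new_num_tokens, embedding_dim], stddev=0.02))

    # Copy the overlapping rows; any extra rows keep their fresh initialization.
    num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
    buffer = new_embeddings.numpy()
    buffer[:num_tokens_to_copy] = old_embeddings.numpy()[:num_tokens_to_copy, :]
    new_embeddings.assign(buffer)

    assert np.allclose(new_embeddings.numpy()[:num_tokens_to_copy], old_embeddings.numpy()[:num_tokens_to_copy])

_get_resized_embeddings above does the same with self.add_weight and get_initializer(self.config.initializer_range), so the new rows follow the model's configured initialization.
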
@@ -306,6 +306,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings.weight = value
        self.embeddings.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        raise NotImplementedError

@@ -388,6 +388,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
    def get_input_embeddings(self):
        return self.word_embedding

    def set_input_embeddings(self, value):
        self.word_embedding.weight = value
        self.word_embedding.vocab_size = value.shape[0]

    def build(self, input_shape):
        initializer = get_initializer(self.initializer_range)
        self.mask_emb = self.add_weight(

@@ -472,6 +472,30 @@ class TFModelTesterMixin:

        model(inputs)

    def test_resize_token_embeddings(self):
        if not self.test_resize_embeddings:
            return
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        INPUT_SHAPE = [1, 10, config.hidden_size]
        for model_class in self.all_model_classes:
            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
                # build the embeddings
                model = model_class(config=config)
                emb_old = model.get_input_embeddings()
                emb_old.build(INPUT_SHAPE)
                # reshape the embeddings
                new_embeddings = model._get_resized_embeddings(emb_old, size)
                # check that the resized embeddings size matches the desired size.
                assert_size = size if size is not None else config.vocab_size
                self.assertEqual(new_embeddings.shape[0], assert_size)
                # check that weights remain the same after resizing
                emb_old_weights = model._get_word_embeddings(emb_old)
                models_equal = True
                for p1, p2 in zip(emb_old_weights.numpy(), new_embeddings.numpy()):
                    if np.sum(abs(p1 - p2)) > 0:
                        models_equal = False
                self.assertTrue(models_equal)

    def test_lm_head_model_random_no_beam_search_generate(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

@@ -169,7 +169,6 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
    )
    test_pruning = True
    test_torchscript = True
    test_resize_embeddings = True
    test_head_masking = True

    def setUp(self):