tf add resize_token_embeddings method (#4351)

* resize token embeddings

* add tokens

* add tokens

* add tokens

* add t5 token method

* add t5 token method

* add t5 token method

* typo

* debugging input

* debugging input

* debug

* debug

* debug

* trying to set embedding tokens properly

* set embeddings for generation head too

* set embeddings for generation head too

* debugging

* debugging

* enable generation

* add base method

* add base method

* add base method

* return logits in the main call

* reverting to generation

* revert back

* set embeddings for the bert main layer

* description

* fix conflicts

* logging

* set base model as self

* refactor

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* tf_bert add method

* v0

* v0

* finalize

* final

* black

* add tests

* revert back the emb call

* comments

* comments

* add the second test

* add vocab size config

* add tf models

* add tf models. add common tests

* remove model specific embedding tests

* stylish

* remove files

* stylez

* Update src/transformers/modeling_tf_transfo_xl.py

change the error.

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* adding unchanged weight test

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Deniz 2020-06-18 15:41:26 -07:00 committed by GitHub
Parent 973433260e
Commit 32e94cff64
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 159 additions and 56 deletions
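For context, a minimal usage sketch of the new API (the checkpoint name and added tokens below are illustrative, not part of this diff): after adding tokens to a tokenizer, resize_token_embeddings grows the model's input token embedding matrix to the new vocabulary size, keeping the existing rows and appending freshly initialized ones.

from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

# extend the vocabulary with two illustrative tokens
tokenizer.add_tokens(["[NEW_TOK1]", "[NEW_TOK2]"])

# grow the input embedding matrix to match the enlarged vocabulary;
# existing rows are kept, the two new rows are freshly initialized
model.resize_token_embeddings(len(tokenizer))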

View file

@@ -60,6 +60,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.config = config
self.vocab_size = config.vocab_size
self.position_embeddings = tf.keras.layers.Embedding(
config.max_position_embeddings,
config.embedding_size,
@@ -515,6 +516,10 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError

View file

@@ -497,6 +497,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
self.embeddings = TFBertEmbeddings(config, name="embeddings")
@@ -506,8 +507,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.

View file

@@ -213,6 +213,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.w
def set_input_embeddings(self, value):
self.w.weight = value
self.w.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError

View file

@@ -422,8 +422,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embeddings
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
raise NotImplementedError

View file

@@ -217,6 +217,10 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
@@ -371,9 +375,6 @@ class TFElectraModel(TFElectraPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.electra = TFElectraMainLayer(config, name="electra")
def get_input_embeddings(self):
return self.electra.embeddings
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
@@ -422,9 +423,6 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
self.electra = TFElectraMainLayer(config, name="electra")
self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")
def get_input_embeddings(self):
return self.electra.embeddings
@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
def call(
self,
@@ -519,9 +517,6 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
self.activation = config.hidden_act
self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
def get_input_embeddings(self):
return self.electra.embeddings
def get_output_embeddings(self):
return self.generator_lm_head

View file

@@ -235,8 +235,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.wte
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def set_input_embeddings(self, value):
self.wte.weight = value
self.wte.vocab_size = self.wte.weight.shape[0]
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.

View file

@@ -227,8 +227,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.tokens_embed
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
def set_input_embeddings(self, value):
self.tokens_embed.weight = value
self.tokens_embed.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.

View file

@@ -101,9 +101,6 @@ class TFRobertaMainLayer(TFBertMainLayer):
super().__init__(config, **kwargs)
self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
def get_input_embeddings(self):
return self.embeddings
class TFRobertaPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and

View file

@@ -884,6 +884,16 @@ class TFT5Model(TFT5PreTrainedModel):
def get_output_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
def get_encoder(self):
return self.encoder
@@ -1011,6 +1021,15 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
def get_output_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
pass
embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)
def get_encoder(self):
return self.encoder

View file

@@ -468,6 +468,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.word_emb
def set_input_embeddings(self, value):
raise NotImplementedError
def _resize_token_embeddings(self, new_num_tokens):
return self.word_emb

View file

@@ -199,6 +199,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
else:
raise NotImplementedError
def set_input_embeddings(self, value):
"""
Set model's input embeddings
Args:
value (:obj:`tf.keras.layers.Layer`):
A module mapping vocabulary to hidden states.
"""
base_model = getattr(self, self.base_model_prefix, self)
if base_model is not self:
base_model.set_input_embeddings(value)
else:
raise NotImplementedError
def get_output_embeddings(self):
"""
Returns the model's output embeddings.
@@ -209,40 +223,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
"""
return None # Overwrite for models with output embeddings
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Variable from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
Args:
new_num_tokens: (`optional`) int
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
If not provided or None: return the provided token Embedding Module.
Return: ``tf.Variable``
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
"""
# if new_num_tokens is None:
# return old_embeddings
# old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
# if old_num_tokens == new_num_tokens:
# return old_embeddings
# # Build new embeddings
# new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
# new_embeddings.to(old_embeddings.weight.device)
# # initialize all new embeddings (in particular added tokens)
# self._init_weights(new_embeddings)
# # Copy token embeddings from the previous weights
# num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
# new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
# return new_embeddings
def resize_token_embeddings(self, new_num_tokens=None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
@@ -256,7 +236,71 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
Return: ``tf.Variable``
Pointer to the input tokens Embeddings Module of the model
"""
raise NotImplementedError
model_embeds = self._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
return model_embeds
def _resize_token_embeddings(self, new_num_tokens):
# get_input_embeddings and set_input_embeddings need to be implemented in base layer.
base_model = getattr(self, self.base_model_prefix, self)
old_embeddings = base_model.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
base_model.set_input_embeddings(new_embeddings)
# Update base model and current model config
self.config.vocab_size = new_num_tokens
base_model.vocab_size = new_num_tokens
return base_model.get_input_embeddings()
def _get_word_embeddings(self, embeddings):
if hasattr(embeddings, "word_embeddings"):
# TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
return embeddings.word_embeddings
elif hasattr(embeddings, "weight"):
# TFSharedEmbeddings
return embeddings.weight
else:
raise ValueError("word embedding is not defined.")
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
""" Build a resized Embedding Variable from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end.
Args:
new_num_tokens: (`optional`) int
New number of tokens in the embedding matrix.
Increasing the size will add newly initialized vectors at the end
Reducing the size will remove vectors from the end
If not provided or None: return the provided token Embedding Module.
Return: ``tf.Variable``
Pointer to the resized word Embedding Module or the old Embedding Module if new_num_tokens is None
"""
word_embeddings = self._get_word_embeddings(old_embeddings)
if new_num_tokens is None:
return word_embeddings
old_num_tokens, old_embedding_dim = word_embeddings.shape
if old_num_tokens == new_num_tokens:
return word_embeddings
# initialize new embeddings
# todo: initializer range is not always passed in config.
init_range = getattr(self.config, "initializer_range", 0.02)
new_embeddings = self.add_weight(
"weight",
shape=[new_num_tokens, old_embedding_dim],
initializer=get_initializer(init_range),
dtype=tf.float32,
)
init_weights = new_embeddings.numpy()
# Copy token embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
new_embeddings.assign(init_weights)
return new_embeddings
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.

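As a rough, self-contained illustration of what the _get_resized_embeddings added above does (assuming TF 2.x eager execution and a deliberately tiny config so the model builds quickly; this mirrors the common test added below): the first min(old, new) rows of the old word-embedding matrix are copied into the resized variable, and the extra rows come from the initializer.

import numpy as np
from transformers import BertConfig, TFBertModel

# tiny illustrative config, not a real checkpoint
config = BertConfig(vocab_size=32, hidden_size=16, num_hidden_layers=1, num_attention_heads=2, intermediate_size=32)
model = TFBertModel(config)
model(model.dummy_inputs)  # run one forward pass so the embedding weights are built

emb_layer = model.get_input_embeddings()
old = model._get_word_embeddings(emb_layer).numpy()
new = model._get_resized_embeddings(emb_layer, config.vocab_size + 10).numpy()

assert new.shape == (config.vocab_size + 10, old.shape[1])  # 10 extra rows
assert np.allclose(new[: config.vocab_size], old)  # original rows preserved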
View file

@@ -306,6 +306,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings.weight = value
self.embeddings.vocab_size = value.shape[0]
def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError

View file

@@ -388,6 +388,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.word_embedding
def set_input_embeddings(self, value):
self.word_embedding.weight = value
self.word_embedding.vocab_size = value.shape[0]
def build(self, input_shape):
initializer = get_initializer(self.initializer_range)
self.mask_emb = self.add_weight(

View file

@@ -472,6 +472,30 @@ class TFModelTesterMixin:
model(inputs)
def test_resize_token_embeddings(self):
if not self.test_resize_embeddings:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
INPUT_SHAPE = [1, 10, config.hidden_size]
for model_class in self.all_model_classes:
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
# build the embeddings
model = model_class(config=config)
emb_old = model.get_input_embeddings()
emb_old.build(INPUT_SHAPE)
# reshape the embeddings
new_embeddings = model._get_resized_embeddings(emb_old, size)
# check that the resized embeddings size matches the desired size.
assert_size = size if size is not None else config.vocab_size
self.assertEqual(new_embeddings.shape[0], assert_size)
# check that weights remain the same after resizing
emb_old_weights = model._get_word_embeddings(emb_old)
models_equal = True
for p1, p2 in zip(emb_old_weights.numpy(), new_embeddings.numpy()):
if np.sum(abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

View file

@@ -169,7 +169,6 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
)
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
def setUp(self):