diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py
index 7fe3a4c2b..a0c2be111 100644
--- a/src/transformers/modeling_tf_albert.py
+++ b/src/transformers/modeling_tf_albert.py
@@ -60,6 +60,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
         super().__init__(**kwargs)
 
         self.config = config
+        self.vocab_size = config.vocab_size
         self.position_embeddings = tf.keras.layers.Embedding(
             config.max_position_embeddings,
             config.embedding_size,
@@ -515,6 +516,10 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings
 
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py
index c4a5b9d5c..eb0c387c2 100644
--- a/src/transformers/modeling_tf_bert.py
+++ b/src/transformers/modeling_tf_bert.py
@@ -497,6 +497,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
+        self.initializer_range = config.initializer_range
         self.output_attentions = config.output_attentions
 
         self.embeddings = TFBertEmbeddings(config, name="embeddings")
@@ -506,8 +507,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py
index 323075012..5f10d3f32 100644
--- a/src/transformers/modeling_tf_ctrl.py
+++ b/src/transformers/modeling_tf_ctrl.py
@@ -213,6 +213,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.w
 
+    def set_input_embeddings(self, value):
+        self.w.weight = value
+        self.w.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
diff --git a/src/transformers/modeling_tf_distilbert.py b/src/transformers/modeling_tf_distilbert.py
index 76e5f42b6..a3fb76e6b 100644
--- a/src/transformers/modeling_tf_distilbert.py
+++ b/src/transformers/modeling_tf_distilbert.py
@@ -422,8 +422,9 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]
 
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError
diff --git a/src/transformers/modeling_tf_electra.py b/src/transformers/modeling_tf_electra.py
index d29770de4..7162dea2c 100644
--- a/src/transformers/modeling_tf_electra.py
+++ b/src/transformers/modeling_tf_electra.py
@@ -217,6 +217,10 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
     def get_input_embeddings(self):
         return self.embeddings
 
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+        self.embeddings.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
@@ -371,9 +375,6 @@ class TFElectraModel(TFElectraPreTrainedModel):
         super().__init__(config, *inputs, **kwargs)
         self.electra = TFElectraMainLayer(config, name="electra")
 
-    def get_input_embeddings(self):
-        return self.electra.embeddings
-
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(self, inputs, **kwargs):
         r"""
@@ -422,9 +423,6 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
         self.electra = TFElectraMainLayer(config, name="electra")
         self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")
 
-    def get_input_embeddings(self):
-        return self.electra.embeddings
-
     @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING)
     def call(
         self,
@@ -519,9 +517,6 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
         self.activation = config.hidden_act
         self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
 
-    def get_input_embeddings(self):
-        return self.electra.embeddings
-
     def get_output_embeddings(self):
         return self.generator_lm_head
 
diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py
index 9e715477e..91a9cf08e 100644
--- a/src/transformers/modeling_tf_gpt2.py
+++ b/src/transformers/modeling_tf_gpt2.py
@@ -235,8 +235,9 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.wte
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.wte.weight = value
+        self.wte.vocab_size = self.wte.weight.shape[0]
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
diff --git a/src/transformers/modeling_tf_openai.py b/src/transformers/modeling_tf_openai.py
index 28294136a..5b8596d67 100644
--- a/src/transformers/modeling_tf_openai.py
+++ b/src/transformers/modeling_tf_openai.py
@@ -227,8 +227,9 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.tokens_embed
 
-    def _resize_token_embeddings(self, new_num_tokens):
-        raise NotImplementedError
+    def set_input_embeddings(self, value):
+        self.tokens_embed.weight = value
+        self.tokens_embed.vocab_size = value.shape[0]
 
     def _prune_heads(self, heads_to_prune):
         """ Prunes heads of the model.
diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py
index 00cadc114..da40db513 100644
--- a/src/transformers/modeling_tf_roberta.py
+++ b/src/transformers/modeling_tf_roberta.py
@@ -101,9 +101,6 @@ class TFRobertaMainLayer(TFBertMainLayer):
         super().__init__(config, **kwargs)
         self.embeddings = TFRobertaEmbeddings(config, name="embeddings")
 
-    def get_input_embeddings(self):
-        return self.embeddings
-
 
 class TFRobertaPreTrainedModel(TFPreTrainedModel):
     """ An abstract class to handle weights initialization and
diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py
index 55d3e4ed3..b25965da4 100644
--- a/src/transformers/modeling_tf_t5.py
+++ b/src/transformers/modeling_tf_t5.py
@@ -884,6 +884,16 @@ class TFT5Model(TFT5PreTrainedModel):
     def get_output_embeddings(self):
         return self.shared
 
+    def set_input_embeddings(self, new_embeddings):
+        self.shared.weight = new_embeddings
+        self.shared.vocab_size = self.shared.weight.shape[0]
+        # retrieve correct absolute scope for embed token wrapper
+        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
+            pass
+        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        self.encoder.set_embed_tokens(embed_tokens)
+        self.decoder.set_embed_tokens(embed_tokens)
+
     def get_encoder(self):
         return self.encoder
 
@@ -1011,6 +1021,16 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
     def get_output_embeddings(self):
         return self.shared
 
+    def set_input_embeddings(self, new_embeddings):
+        self.shared.weight = new_embeddings
+        self.shared.vocab_size = self.shared.weight.shape[0]
+        # retrieve correct absolute scope for embed token wrapper
+        with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
+            pass
+        embed_tokens = _NoLayerEmbedTokens(self.shared, abs_scope_name=shared_abs_scope_name)
+        self.encoder.set_embed_tokens(embed_tokens)
+        self.decoder.set_embed_tokens(embed_tokens)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py
index b55effafd..a3ebb82d0 100644
--- a/src/transformers/modeling_tf_transfo_xl.py
+++ b/src/transformers/modeling_tf_transfo_xl.py
@@ -468,6 +468,9 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.word_emb
 
+    def set_input_embeddings(self, value):
+        raise NotImplementedError
+
     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb
 
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 4d31c00bd..e6a4a37a1 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -199,6 +199,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         else:
             raise NotImplementedError
 
+    def set_input_embeddings(self, value):
+        """
+        Set model's input embeddings
+
+        Args:
+            value (:obj:`tf.keras.layers.Layer`):
+                A module mapping vocabulary to hidden states.
+        """
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError
+
     def get_output_embeddings(self):
         """ Returns the model's output embeddings.
 
@@ -209,40 +223,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         """
         return None  # Overwrite for models with output embeddings
 
-    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
-        """ Build a resized Embedding Variable from a provided token Embedding Module.
-            Increasing the size will add newly initialized vectors at the end
-            Reducing the size will remove vectors from the end
-
-        Args:
-            new_num_tokens: (`optional`) int
-                New number of tokens in the embedding matrix.
-                Increasing the size will add newly initialized vectors at the end
-                Reducing the size will remove vectors from the end
-                If not provided or None: return the provided token Embedding Module.
-        Return: ``tf.Variable``
-            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
-        """
-        # if new_num_tokens is None:
-        #     return old_embeddings
-
-        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
-        # if old_num_tokens == new_num_tokens:
-        #     return old_embeddings
-
-        # # Build new embeddings
-        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        # new_embeddings.to(old_embeddings.weight.device)
-
-        # # initialize all new embeddings (in particular added tokens)
-        # self._init_weights(new_embeddings)
-
-        # # Copy token embeddings from the previous weights
-        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
-        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
-
-        # return new_embeddings
-
     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
         Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
 
         Arguments:
             new_num_tokens: (`optional`) int:
                 New number of tokens in the embedding matrix.
                 Increasing the size will add newly initialized vectors at the end.
                 Reducing the size will remove vectors from the end.
                 If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.
 
@@ -256,7 +236,71 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         Return: ``tf.Variable``
             Pointer to the input tokens Embeddings Module of the model
         """
-        raise NotImplementedError
+        if new_num_tokens is None:
+            return self.get_input_embeddings()
+
+        model_embeds = self._resize_token_embeddings(new_num_tokens)
+        return model_embeds
+
+    def _resize_token_embeddings(self, new_num_tokens):
+        # get_input_embeddings and set_input_embeddings need to be implemented in the base layer.
+        base_model = getattr(self, self.base_model_prefix, self)
+        old_embeddings = base_model.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        base_model.set_input_embeddings(new_embeddings)
+        # Update the base model and current model config
+        self.config.vocab_size = new_num_tokens
+        base_model.vocab_size = new_num_tokens
+        return base_model.get_input_embeddings()
+
+    def _get_word_embeddings(self, embeddings):
+        if hasattr(embeddings, "word_embeddings"):
+            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
+            return embeddings.word_embeddings
+        elif hasattr(embeddings, "weight"):
+            # TFSharedEmbeddings
+            return embeddings.weight
+        else:
+            raise ValueError("word embedding is not defined.")
+
+    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
+        """ Build a resized Embedding Variable from a provided token Embedding Module.
+            Increasing the size will add newly initialized vectors at the end
+            Reducing the size will remove vectors from the end.
+
+        Args:
+            new_num_tokens: (`optional`) int
+                New number of tokens in the embedding matrix.
+                Increasing the size will add newly initialized vectors at the end
+                Reducing the size will remove vectors from the end
+                If not provided or None: return the provided token Embedding Module.
+        Return: ``tf.Variable``
+            Pointer to the resized word Embedding Module or the old Embedding Module if new_num_tokens is None
+        """
+        word_embeddings = self._get_word_embeddings(old_embeddings)
+        if new_num_tokens is None:
+            return word_embeddings
+        old_num_tokens, old_embedding_dim = word_embeddings.shape
+        if old_num_tokens == new_num_tokens:
+            return word_embeddings
+
+        # initialize new embeddings
+        # todo: initializer range is not always passed in config.
+        init_range = getattr(self.config, "initializer_range", 0.02)
+        new_embeddings = self.add_weight(
+            "weight",
+            shape=[new_num_tokens, old_embedding_dim],
+            initializer=get_initializer(init_range),
+            dtype=tf.float32,
+        )
+        init_weights = new_embeddings.numpy()
+
+        # Copy token embeddings from the previous weights
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
+        new_embeddings.assign(init_weights)
+
+        return new_embeddings
 
     def prune_heads(self, heads_to_prune):
         """ Prunes heads of the base model.
diff --git a/src/transformers/modeling_tf_xlm.py b/src/transformers/modeling_tf_xlm.py
index 53f3e699d..2d76946c5 100644
--- a/src/transformers/modeling_tf_xlm.py
+++ b/src/transformers/modeling_tf_xlm.py
@@ -306,6 +306,10 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.embeddings
 
+    def set_input_embeddings(self, value):
+        self.embeddings.weight = value
+        self.embeddings.vocab_size = value.shape[0]
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError
 
diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py
index 5532bcb2d..c33133893 100644
--- a/src/transformers/modeling_tf_xlnet.py
+++ b/src/transformers/modeling_tf_xlnet.py
@@ -388,6 +388,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
     def get_input_embeddings(self):
         return self.word_embedding
 
+    def set_input_embeddings(self, value):
+        self.word_embedding.weight = value
+        self.word_embedding.vocab_size = value.shape[0]
+
     def build(self, input_shape):
         initializer = get_initializer(self.initializer_range)
         self.mask_emb = self.add_weight(
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 40803f709..11ceb5ab3 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -472,6 +472,30 @@ class TFModelTesterMixin:
 
             model(inputs)
 
+    def test_resize_token_embeddings(self):
+        if not self.test_resize_embeddings:
+            return
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        INPUT_SHAPE = [1, 10, config.hidden_size]
+        for model_class in self.all_model_classes:
+            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
+                # build the embeddings
+                model = model_class(config=config)
+                emb_old = model.get_input_embeddings()
+                emb_old.build(INPUT_SHAPE)
+                # resize the embeddings
+                new_embeddings = model._get_resized_embeddings(emb_old, size)
+                # check that the resized embeddings size matches the desired size
+                assert_size = size if size is not None else config.vocab_size
+                self.assertEqual(new_embeddings.shape[0], assert_size)
+                # check that weights remain the same after resizing
+                emb_old_weights = model._get_word_embeddings(emb_old)
+                models_equal = True
+                for p1, p2 in zip(emb_old_weights.numpy(), new_embeddings.numpy()):
+                    if np.sum(abs(p1 - p2)) > 0:
+                        models_equal = False
+                self.assertTrue(models_equal)
+
     def test_lm_head_model_random_no_beam_search_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
diff --git a/tests/test_modeling_tf_distilbert.py b/tests/test_modeling_tf_distilbert.py
index c4fb2f103..05397ff5b 100644
--- a/tests/test_modeling_tf_distilbert.py
+++ b/tests/test_modeling_tf_distilbert.py
@@ -169,7 +169,6 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
     )
     test_pruning = True
     test_torchscript = True
-    test_resize_embeddings = True
    test_head_masking = True
 
     def setUp(self):
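
Note (not part of the diff): the snippet below is a minimal usage sketch of the TF resizing API this patch introduces, assuming the changes above are applied. The checkpoint name and added tokens are placeholders chosen purely for illustration.

    from transformers import BertTokenizer, TFBertModel

    # placeholder checkpoint; any TF architecture that implements
    # set_input_embeddings should behave the same way
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertModel.from_pretrained("bert-base-uncased")

    # grow the vocabulary on the tokenizer side, then resize the model's
    # input embedding matrix to match: existing rows are copied over and
    # the new rows are freshly initialized from config.initializer_range
    tokenizer.add_tokens(["[NEW_TOK_1]", "[NEW_TOK_2]"])
    model.resize_token_embeddings(len(tokenizer))
    assert model.config.vocab_size == len(tokenizer)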