Merge pull request #2254 from huggingface/fix-tfroberta
adding positional embeds masking to TFRoBERTa
This commit is contained in:
Commit 645713e2cb
@@ -20,7 +20,6 @@ from __future__ import (absolute_import, division, print_function,
 
 import logging
 
-import numpy as np
 import tensorflow as tf
 
 from .configuration_roberta import RobertaConfig
@@ -46,17 +45,40 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
         super(TFRobertaEmbeddings, self).__init__(config, **kwargs)
         self.padding_idx = 1
 
+    def create_position_ids_from_input_ids(self, x):
+        """ Replace non-padding symbols with their position numbers. Position numbers begin at
+        padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
+        `utils.make_positions`.
+        :param tf.Tensor x:
+        :return tf.Tensor:
+        """
+        mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
+        incremental_indices = tf.math.cumsum(mask, axis=1) * mask
+        return incremental_indices + self.padding_idx
+
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """ We are provided embeddings directly. We cannot infer which are padded so just generate
+        sequential position ids.
+        :param tf.Tensor inputs_embeds:
+        :return tf.Tensor:
+        """
+        seq_length = shape_list(inputs_embeds)[1]
+
+        position_ids = tf.range(self.padding_idx + 1,
+                                seq_length + self.padding_idx + 1,
+                                dtype=tf.int32)[tf.newaxis, :]
+        return position_ids
+
     def _embedding(self, inputs, training=False):
         """Applies embedding based on inputs tensor."""
         input_ids, position_ids, token_type_ids, inputs_embeds = inputs
 
-        if input_ids is not None:
-            seq_length = shape_list(input_ids)[1]
-        else:
-            seq_length = shape_list(inputs_embeds)[1]
-
         if position_ids is None:
-            position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = self.create_position_ids_from_input_ids(input_ids)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
 
         return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
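
For reference, a minimal sketch of what the new create_position_ids_from_input_ids masking computes, assuming TensorFlow 2.x eager execution and RoBERTa's padding_idx of 1 (the input tensor below is an illustrative made-up batch): real tokens are numbered starting at padding_idx + 1, while padded positions keep the id padding_idx, mirroring fairseq's utils.make_positions.

import tensorflow as tf

padding_idx = 1  # RoBERTa's <pad> token id, as set in TFRobertaEmbeddings.__init__
input_ids = tf.constant([[5, 7, 9, 1, 1],
                         [8, 1, 1, 1, 1]])  # illustrative batch; 1 marks padding

# Same arithmetic as create_position_ids_from_input_ids:
mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=tf.int32)
position_ids = tf.math.cumsum(mask, axis=1) * mask + padding_idx

print(position_ids.numpy())
# [[2 3 4 1 1]
#  [2 1 1 1 1]]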
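
When the model receives only inputs_embeds, there are no token ids from which to detect padding, so create_position_ids_from_inputs_embeds falls back to plain sequential positions offset by padding_idx. A sketch of that fallback under the same assumptions (the seq_length value here is made up; in the model it comes from shape_list(inputs_embeds)[1]):

import tensorflow as tf

padding_idx = 1
seq_length = 5  # illustrative; derived from inputs_embeds in the model

position_ids = tf.range(padding_idx + 1,
                        seq_length + padding_idx + 1,
                        dtype=tf.int32)[tf.newaxis, :]
print(position_ids.numpy())  # [[2 3 4 5 6]]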