Adding positional embedding masking to TFRoBERTa

thomwolf 2019-12-21 15:24:48 +01:00
Parent 5b7fb6a4a1
Commit 77676c27d2
1 changed file with 29 additions and 7 deletions


@@ -20,7 +20,6 @@ from __future__ import (absolute_import, division, print_function,
import logging
import numpy as np
import tensorflow as tf
from .configuration_roberta import RobertaConfig
@@ -46,17 +45,40 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
         super(TFRobertaEmbeddings, self).__init__(config, **kwargs)
         self.padding_idx = 1

+    def create_position_ids_from_input_ids(self, x):
+        """ Replace non-padding symbols with their position numbers. Position numbers begin at
+        padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
+        `utils.make_positions`.
+
+        :param tf.Tensor x:
+        :return tf.Tensor:
+        """
+        mask = tf.cast(tf.math.not_equal(x, self.padding_idx), dtype=tf.int32)
+        incremental_indices = tf.math.cumsum(mask, axis=1) * mask
+        return incremental_indices + self.padding_idx
+
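A minimal sketch (editor's illustration, not part of the diff) of what the cumsum trick above produces, assuming padding_idx = 1 and a toy batch of token ids:

    import tensorflow as tf

    padding_idx = 1
    x = tf.constant([[5, 6, 7, 1, 1]])                            # last two tokens are padding
    mask = tf.cast(tf.math.not_equal(x, padding_idx), tf.int32)   # [[1, 1, 1, 0, 0]]
    incremental_indices = tf.math.cumsum(mask, axis=1) * mask     # [[1, 2, 3, 0, 0]]
    position_ids = incremental_indices + padding_idx              # [[2, 3, 4, 1, 1]]
    # non-padding tokens count up from padding_idx + 1; padded slots stay at padding_idx

This keeps the position embedding of every padded slot pinned to the padding position, mirroring fairseq's convention.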
+    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+        """ We are provided embeddings directly. We cannot infer which tokens are padded, so we just
+        generate sequential position ids.
+
+        :param tf.Tensor inputs_embeds:
+        :return tf.Tensor:
+        """
+        seq_length = shape_list(inputs_embeds)[1]
+        position_ids = tf.range(self.padding_idx + 1,
+                                seq_length + self.padding_idx + 1,
+                                dtype=tf.int32)[tf.newaxis, :]
+        return position_ids
+
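For the inputs_embeds path, a similarly small sketch (editor's illustration): with no token ids there is no way to tell which rows are padding, so every sequence simply gets the ids padding_idx + 1 through padding_idx + seq_length:

    import tensorflow as tf

    padding_idx = 1
    seq_length = 4   # e.g. inputs_embeds of shape (batch, 4, hidden)
    position_ids = tf.range(padding_idx + 1, seq_length + padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
    # -> [[2, 3, 4, 5]]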
     def _embedding(self, inputs, training=False):
         """Applies embedding based on inputs tensor."""
         input_ids, position_ids, token_type_ids, inputs_embeds = inputs
-        if input_ids is not None:
-            seq_length = shape_list(input_ids)[1]
-        else:
-            seq_length = shape_list(inputs_embeds)[1]
         if position_ids is None:
-            position_ids = tf.range(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=tf.int32)[tf.newaxis, :]
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = self.create_position_ids_from_input_ids(input_ids)
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
         return super(TFRobertaEmbeddings, self)._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
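With the dispatch above, padded tokens keep position id padding_idx whenever input_ids are supplied. A rough end-to-end usage sketch, assuming a recent transformers release (the tokenizer __call__ API shown here postdates this commit) and the standard roberta-base checkpoint:

    from transformers import RobertaTokenizer, TFRobertaModel

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = TFRobertaModel.from_pretrained("roberta-base")

    # Padding the shorter sentence inserts pad token id 1; inside TFRobertaEmbeddings
    # those positions now also keep position id 1 instead of getting a running index.
    batch = tokenizer(["Hello world", "Hi"], padding=True, return_tensors="tf")
    outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])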