This commit is contained in:
ArrowLuo 2020-10-30 13:25:05 +08:00
Parent b0ea315de1
Commit f08a73055d
24 changed files: 6768 additions and 21 deletions

1
.gitignore vendored Normal file
View file

@@ -0,0 +1 @@
.idea

21
LICENSE
View file

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2020 ArrowLuo
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

91
README.md Normal file
View file

@@ -0,0 +1,91 @@
# Preliminary
Execute the scripts below in the main folder first.
```
cd pytorch_pretrained_bert/bert-base-uncased/
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt
mv bert-base-uncased-vocab.txt vocab.txt
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz
tar -xvf bert-base-uncased.tar.gz
rm bert-base-uncased.tar.gz
cd ../../
```
# Finetune on YouCookII
## Retrieval
```
INIT_MODEL=<from second phase>
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_retrieval_task.py \
--do_train --num_thread_reader=16 \
--epochs=5 --batch_size=32 \
--n_pair=1 --n_display=100 \
--youcook_train_csv data/youcookii_train.csv \
--youcook_val_csv data/youcookii_val.csv \
--youcook_caption_path data/youcookii_caption.pickle \
--youcook_features_path_2D data/youcookii_videos_features \
--output_dir ckpt_youcook_retrieval --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 48 --max_frames 48 \
--batch_size_val 200 --visual_num_hidden_layers 6 \
--datatype youcook --init_model ${INIT_MODEL}
```
## Caption
```
INIT_MODEL=<from second phase>
python -m torch.distributed.launch --nproc_per_node=4 train_transcript_distributed.py \
--do_train --num_thread_reader=4 \
--epochs=5 --batch_size=16 \
--n_pair=-1 --n_display=100 \
--youcook_train_csv data/youcookii_train.csv \
--youcook_val_csv data/youcookii_val.csv \
--youcook_caption_path data/youcookii_caption_transcript.pickle \
--youcook_features_path_2D data/youcookii_videos_features \
--output_dir ckpt_youcook_caption --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 128 --max_frames 96 \
--batch_size_val 64 --visual_num_hidden_layers 6 \
--decoder_num_hidden_layers 3 \
--init_model ${INIT_MODEL}
```
# Pretrain on HowTo100M
## Phase I
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
python -m torch.distributed.launch --nproc_per_node=8 ${MODEL_PATH}/train_transcript_distributed.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=1920 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/pre_s3d_L48_V6_D3_Phase1 \
--features_path_2D ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--caption_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --gradient_accumulation_steps 16 \
--sampled_use_mil --load_checkpoint
```
## Phase II
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
INIT_MODEL=<from first phase>
python -m torch.distributed.launch --nproc_per_node=8 ${MODEL_PATH}/train_transcript_distributed.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=960 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/pre_s3d_L48_V6_D3_Phase2 \
--features_path_2D ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--caption_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --decoder_num_hidden_layers 3 \
--gradient_accumulation_steps 60 \
--cross_model --pretrain_with_joint_sim --sampled_use_mil \
--pretrain_enhance_vmodal --pretrain_without_decoder \
--load_checkpoint --init_model ${INIT_MODEL}
```

9
data_prefetch_unitl.py Normal file
View file

@@ -0,0 +1,9 @@
# pip install prefetch_generator
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator
class DataLoaderX(DataLoader):
def __iter__(self):
# Transforms the generator into a background-thread generator.
return BackgroundGenerator(super().__iter__(), max_prefetch=1)
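A minimal usage sketch for `DataLoaderX` (illustrative only; the dataset below is a stand-in, and any `torch.utils.data.Dataset` works since prefetching happens transparently through `__iter__`):
```
# Sketch: DataLoaderX is a drop-in replacement for DataLoader that prepares
# the next batch in a background thread while the current one is consumed.
import torch
from torch.utils.data import TensorDataset

dummy_dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoaderX(dummy_dataset, batch_size=16, shuffle=True, num_workers=2)
for features, labels in loader:
    pass  # a training/eval step would go here
```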

28
metrics.py Normal file
View file

@@ -0,0 +1,28 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import numpy as np
def compute_metrics(x):
sx = np.sort(-x, axis=1)
d = np.diag(-x)
d = d[:, np.newaxis]
ind = sx - d
ind = np.where(ind == 0)
ind = ind[1]
metrics = {}
metrics['R1'] = float(np.sum(ind == 0)) / len(ind)
metrics['R5'] = float(np.sum(ind < 5)) / len(ind)
metrics['R10'] = float(np.sum(ind < 10)) / len(ind)
metrics['MR'] = np.median(ind) + 1
return metrics
def print_computed_metrics(metrics):
r1 = metrics['R1']
r5 = metrics['R5']
r10 = metrics['R10']
mr = metrics['MR']
print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr))
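As a usage sketch (illustrative only): `compute_metrics` expects a query-by-candidate similarity matrix whose diagonal holds the ground-truth pairs, so a random matrix below merely exercises the call path:
```
# Sketch: sim_matrix[i, j] is the similarity of query i and candidate j, with
# matching pairs on the diagonal; random values only demonstrate the API.
import numpy as np

sim_matrix = np.random.randn(100, 100)
metrics = compute_metrics(sim_matrix)
print_computed_metrics(metrics)
```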

View file

View file

@@ -0,0 +1,117 @@
""" Manage beam search info structure.
Heavily borrowed from OpenNMT-py.
For code in OpenNMT-py, please check the following link:
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
"""
import torch
import numpy as np
class Constants():
def __init__(self):
self.PAD = 0
self.UNK = 1
self.BOS = 2
self.EOS = 3
self.PAD_WORD = '[MASK]'
self.UNK_WORD = '[UNK]'
self.BOS_WORD = '[CLS]'
self.EOS_WORD = '[SEP]'
@classmethod
def from_tokenizer(cls, tokenizer):
instance = cls()
instance.PAD = tokenizer.vocab[instance.PAD_WORD]
instance.UNK = tokenizer.vocab[instance.UNK_WORD]
instance.BOS = tokenizer.vocab[instance.BOS_WORD]
instance.EOS = tokenizer.vocab[instance.EOS_WORD]
return instance
class Beam():
''' Beam search '''
def __init__(self, size, device=False, tokenizer=None):
if tokenizer is None:
self.constants = Constants()
else:
self.constants = Constants.from_tokenizer(tokenizer)
self.size = size
self._done = False
# The score for each hypothesis on the beam.
self.scores = torch.zeros((size,), dtype=torch.float, device=device)
self.all_scores = []
# The backpointers at each time-step.
self.prev_ks = []
# The outputs at each time-step.
self.next_ys = [torch.full((size,), self.constants.BOS, dtype=torch.long, device=device)]
def get_current_state(self):
"Get the outputs for the current timestep."
return self.get_tentative_hypothesis()
def get_current_origin(self):
"Get the backpointers for the current timestep."
return self.prev_ks[-1]
@property
def done(self):
return self._done
def advance(self, word_prob, word_length=None):
"Update beam status and check if finished or not."
num_words = word_prob.size(1)
# Sum the previous scores.
if len(self.prev_ks) > 0:
beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
else:
beam_lk = word_prob[0]
flat_beam_lk = beam_lk.view(-1)
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort
self.all_scores.append(self.scores)
self.scores = best_scores
# bestScoresId is flattened as a (beam x word) array,
# so we need to calculate which word and beam each score came from
prev_k = best_scores_id // num_words  # integer division to recover the beam index
self.prev_ks.append(prev_k)
self.next_ys.append(best_scores_id - prev_k * num_words)
# End condition is when top-of-beam is EOS.
if self.next_ys[-1][0].item() == self.constants.EOS:
self._done = True
return self._done
def sort_scores(self):
"Sort the scores."
return torch.sort(self.scores, 0, True)
def get_the_best_score_and_idx(self):
"Get the score of the best in the beam."
scores, ids = self.sort_scores()
return scores[1], ids[1]
def get_tentative_hypothesis(self):
"Get the decoded sequence for the current timestep."
if len(self.next_ys) == 1:
dec_seq = self.next_ys[0].unsqueeze(1)
else:
_, keys = self.sort_scores()
hyps = [self.get_hypothesis(k) for k in keys]
hyps = [[self.constants.BOS] + h for h in hyps]
dec_seq = torch.LongTensor(hyps)
return dec_seq
def get_hypothesis(self, k):
""" Walk back to construct the full hypothesis. """
hyp = []
for j in range(len(self.prev_ks) - 1, -1, -1):
hyp.append(self.next_ys[j+1][k])
k = self.prev_ks[j][k]
return list(map(lambda x: x.item(), hyp[::-1]))
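A short driver sketch (not part of the file) that advances a `Beam` with random log-probabilities until EOS or a step limit; the vocabulary size and step cap are arbitrary choices for the sketch:
```
# Sketch: drive the Beam defined above with dummy scores.
import torch

beam = Beam(size=3, device='cpu')
vocab_size = 8            # arbitrary toy vocabulary
for _ in range(10):       # arbitrary step cap
    word_prob = torch.log_softmax(torch.randn(beam.size, vocab_size), dim=-1)
    if beam.advance(word_prob):
        break             # top-of-beam emitted EOS
print(beam.get_tentative_hypothesis())  # (beam_size, length) tensor of token ids
```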

View file

@@ -0,0 +1,711 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
"""Constructs BertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super(BertOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class BertPreTrainingHeads(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertPreTrainingHeads, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
model = self.module if hasattr(self, 'module') else self
base_model = getattr(model, "bert", model) # get the base model if needed
old_embeddings = base_model.embeddings.word_embeddings
# get_resized_embeddings <-----
if new_num_tokens is None:
return old_embeddings
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
if old_num_tokens == new_num_tokens:
return old_embeddings
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
device = old_embeddings.weight.device
new_embeddings.to(device)
# initialize all new embeddings (in particular added tokens)
self.init_bert_weights(new_embeddings)
# Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
# ----->
# Update base model and current model config
base_model.embeddings.word_embeddings = new_embeddings
base_model.config.vocab_size = new_num_tokens
model.config.vocab_size = new_num_tokens
model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
'''
abstract from `from_pretrained`
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
config.type_vocab_size = type_vocab_size
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_classes for BertForSequenceClassification)
"""
config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedBertModel.init_preweight(model, state_dict)
return model
class BertModel(PreTrainedBertModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a BertConfig class instance with the configuration to build a new model
Inputs:
`type`: a str indicating which masking will be used in the attention, chosen from [`bi`, `seq`, `gen`]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first token of the
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def forward(self, type, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
attention_sentenceA_mask=None,
attention_sentenceB_mask=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
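For a quick sanity check of the model defined above, a small forward pass works; this is a sketch only, with arbitrary tiny sizes, and the `type` argument is simply passed through as in the signature:
```
# Sketch: smoke test of BertModel, not part of the original file.
import torch

config = BertConfig(vocab_size_or_config_json_file=1000, hidden_size=48,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=96)
model = BertModel(config)
input_ids = torch.randint(0, 1000, (2, 8))
attention_mask = torch.ones(2, 8, dtype=torch.long)
encoded_layers, pooled_output = model("bi", input_ids, attention_mask=attention_mask)
print(len(encoded_layers), encoded_layers[-1].shape, pooled_output.shape)
# -> 2 torch.Size([2, 8, 48]) torch.Size([2, 48])
```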

View file

@@ -0,0 +1,12 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 1024,
"num_attention_heads": 12,
"num_hidden_layers": 2,
"vocab_size": 768
}
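This JSON follows the same schema the `BertConfig`/`CrossBertConfig` classes in this commit read; a minimal loading sketch, assuming the file is saved as `decoder_bert_config.json` (the actual file name is not visible in this diff):
```
# Sketch only: the file name below is an assumption, since the diff does not
# show what this config file is called.
config = BertConfig.from_json_file("decoder_bert_config.json")
print(config.num_hidden_layers, config.max_position_embeddings)  # 2 1024
```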

View file

@@ -0,0 +1,624 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'cross-bert-base': "[TODO]",
'cross-bert-large': "[TODO]",
}
CONFIG_NAME = 'cross_bert_config.json'
WEIGHTS_NAME = 'cross_pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class CrossBertConfig(object):
"""Configuration class to store the configuration of a `CrossBertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
"""Constructs CrossBertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `CrossBertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`CrossBertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `CrossBertConfig` from a Python dictionary of parameters."""
config = CrossBertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `CrossBertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as CrossBertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class CrossBertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(CrossBertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class CrossBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(CrossBertEmbeddings, self).__init__()
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, concat_embeddings, concat_type=None):
batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
if concat_type is None:
concat_type = torch.zeros(batch_size, seq_length, dtype=torch.long, device=concat_embeddings.device)
position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device)
position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1)
token_type_embeddings = self.token_type_embeddings(concat_type)
position_embeddings = self.position_embeddings(position_ids)
embeddings = concat_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class CrossBertSelfAttention(nn.Module):
def __init__(self, config):
super(CrossBertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the CrossBertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class CrossBertSelfOutput(nn.Module):
def __init__(self, config):
super(CrossBertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class CrossBertAttention(nn.Module):
def __init__(self, config):
super(CrossBertAttention, self).__init__()
self.self = CrossBertSelfAttention(config)
self.output = CrossBertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class CrossBertIntermediate(nn.Module):
def __init__(self, config):
super(CrossBertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class CrossBertOutput(nn.Module):
def __init__(self, config):
super(CrossBertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class CrossBertLayer(nn.Module):
def __init__(self, config):
super(CrossBertLayer, self).__init__()
self.attention = CrossBertAttention(config)
self.intermediate = CrossBertIntermediate(config)
self.output = CrossBertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class CrossBertEncoder(nn.Module):
def __init__(self, config):
super(CrossBertEncoder, self).__init__()
layer = CrossBertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class CrossBertPooler(nn.Module):
def __init__(self, config):
super(CrossBertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class CrossBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(CrossBertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class CrossBertLMPredictionHead(nn.Module):
def __init__(self, config, cross_bert_model_embedding_weights):
super(CrossBertLMPredictionHead, self).__init__()
self.transform = CrossBertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(cross_bert_model_embedding_weights.size(1),
cross_bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = cross_bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(cross_bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class CrossBertOnlyMLMHead(nn.Module):
def __init__(self, config, cross_bert_model_embedding_weights):
super(CrossBertOnlyMLMHead, self).__init__()
self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class CrossBertOnlyNSPHead(nn.Module):
def __init__(self, config):
super(CrossBertOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class CrossBertPreTrainingHeads(nn.Module):
def __init__(self, config, cross_bert_model_embedding_weights):
super(CrossBertPreTrainingHeads, self).__init__()
self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedCrossBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedCrossBertModel, self).__init__()
if not isinstance(config, CrossBertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `CrossBertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_cross_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, CrossBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
'''
abstract from `from_pretrained`
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = CrossBertConfig.from_json_file(config_file)
config.type_vocab_size = type_vocab_size
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'cross_bert') else 'cross_bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
"""
Instantiate a PreTrainedCrossBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `cross_bert-base`
. `cross_bert-large`
- a path or url to a pretrained model archive containing:
. `cross-bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a CrossBertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific CrossBert class
(ex: num_classes for CrossBertForSequenceClassification)
"""
config, state_dict = PreTrainedCrossBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedCrossBertModel.init_preweight(model, state_dict)
return model
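# A minimal usage sketch for the `from_pretrained` flow above.  The archive
# path is hypothetical: any directory holding a `cross-bert_config.json`
# (and, optionally, a weight file) would do; `CrossBertModel` is the concrete
# subclass defined below.
def _example_load_cross_bert(archive="./cross-bert-base"):
    model = CrossBertModel.from_pretrained(archive, type_vocab_size=2)
    return model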
class CrossBertModel(PreTrainedCrossBertModel):
def __init__(self, config):
super(CrossBertModel, self).__init__(config)
self.embeddings = CrossBertEmbeddings(config)
self.encoder = CrossBertEncoder(config)
self.pooler = CrossBertPooler(config)
self.apply(self.init_cross_bert_weights)
def forward(self, type, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True,
attention_sentenceA_mask=None,
attention_sentenceB_mask=None):
if attention_mask is None:
attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))
if concat_type is None:
concat_type = torch.zeros_like(attention_mask)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT; we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(concat_input, concat_type)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
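# A minimal sketch of the mask arithmetic used in `forward` above: a 2D
# padding mask (1 = attend, 0 = masked) becomes an additive bias that is 0.0
# for kept positions and -10000.0 for masked ones, broadcastable over
# [batch, heads, from_len, to_len].  Sizes are hypothetical.
def _example_extended_attention_mask():
    attention_mask = torch.tensor([[1, 1, 1, 0]])                # [batch=1, to_len=4]
    extended = attention_mask.unsqueeze(1).unsqueeze(2).float()  # [1, 1, 1, 4]
    extended = (1.0 - extended) * -10000.0
    return extended  # tensor([[[[0., 0., 0., -10000.]]]])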


@ -0,0 +1,14 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 2,
"vocab_size": 30522,
"num_decoder_layers": 1,
"max_target_embeddings": 512
}


@ -0,0 +1,713 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'decoder-bert-base': "[TODO]",
'decoder-bert-large': "[TODO]",
}
CONFIG_NAME = 'decoder_bert_config.json'
WEIGHTS_NAME = 'decoder_pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
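# A quick numeric check (hypothetical inputs): the erf-based gelu above and
# the tanh approximation quoted in its docstring stay very close.
def _example_compare_gelu():
    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
    exact = gelu(x)
    approx = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi)
                                         * (x + 0.044715 * torch.pow(x, 3))))
    return (exact - approx).abs().max()  # small; the two curves nearly coincide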
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_vocab_size=2,
initializer_range=0.02,
max_target_embeddings=128,
num_decoder_layers=1):
"""Constructs BertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
max_target_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
num_decoder_layers: Number of layers in the Transformer decoder.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.max_target_embeddings = max_target_embeddings
self.num_decoder_layers = num_decoder_layers
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
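# A minimal sketch of round-tripping a config through JSON, using small
# hypothetical overrides; `from_dict` rebuilds the same attributes.
def _example_bert_config_roundtrip():
    config = BertConfig(vocab_size_or_config_json_file=30522,
                        num_decoder_layers=2,
                        max_target_embeddings=128)
    restored = BertConfig.from_dict(json.loads(config.to_json_string()))
    assert restored.num_decoder_layers == config.num_decoder_layers
    return restored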
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as DecoderBertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class DecoderBertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(DecoderBertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
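# A minimal sketch of the weight tying done by the MLM head above: the output
# projection reuses the word-embedding matrix and only adds a per-token bias.
# Dimensions are hypothetical; assumes the pure-PyTorch LayerNorm fallback
# (with apex installed the fused op may expect CUDA tensors).
def _example_tied_mlm_head():
    config = BertConfig(vocab_size_or_config_json_file=100, hidden_size=16)
    word_embeddings = nn.Embedding(100, 16)
    head = BertOnlyMLMHead(config, word_embeddings.weight)
    assert head.predictions.decoder.weight is word_embeddings.weight
    hidden_states = torch.zeros(2, 5, 16)
    return head(hidden_states)  # [2, 5, 100] vocabulary scores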
class PreTrainedDecoderBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedDecoderBertModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, DecoderBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
model = self.module if hasattr(self, 'module') else self
base_model = getattr(model, "bert", model) # get the base model if needed
old_embeddings = base_model.embeddings.word_embeddings
# get_resized_embeddings <-----
if new_num_tokens is None:
return old_embeddings
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
if old_num_tokens == new_num_tokens:
return old_embeddings
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
device = old_embeddings.weight.device
new_embeddings.to(device)
# initialize all new embeddings (in particular added tokens)
self.init_bert_weights(new_embeddings)
# Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
# ----->
# Update base model and current model config
base_model.embeddings.word_embeddings = new_embeddings
base_model.config.vocab_size = new_num_tokens
model.config.vocab_size = new_num_tokens
model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
'''
Factored out of `from_pretrained`: resolve the config file and (optionally) the state dict.
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
config.type_vocab_size = type_vocab_size
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
"""
Instantiate a PreTrainedDecoderBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_classes for BertForSequenceClassification)
"""
config, state_dict = PreTrainedDecoderBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedDecoderBertModel.init_preweight(model, state_dict)
return model
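# A standalone sketch of the idea behind `resize_token_embeddings` above:
# allocate a larger embedding table and copy the existing rows so previously
# learned tokens keep their vectors.  Sizes are hypothetical.
def _example_grow_embeddings(old_num_tokens=100, new_num_tokens=105, dim=16):
    old = nn.Embedding(old_num_tokens, dim)
    new = nn.Embedding(new_num_tokens, dim)
    num_to_copy = min(old_num_tokens, new_num_tokens)
    new.weight.data[:num_to_copy, :] = old.weight.data[:num_to_copy, :]
    return new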
# import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.bmm(attn, v)
return output, attn
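# A shape sketch for the scaled dot-product attention above: q, k, v are
# [batch, len, dim] and the temperature is typically sqrt(dim).  All sizes
# are hypothetical.
def _example_scaled_dot_product_attention():
    batch, length, dim = 2, 5, 16
    q, k, v = (torch.randn(batch, length, dim) for _ in range(3))
    attention = ScaledDotProductAttention(temperature=dim ** 0.5)
    output, attn = attention(q, k, v)   # output: [2, 5, 16], attn: [2, 5, 5]
    return output.shape, attn.shape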
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, config):
super(MultiHeadAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, q, k, v, attention_mask):
mixed_query_layer = self.query(q)
mixed_key_layer = self.key(k)
mixed_value_layer = self.value(v)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer, attention_scores
class PositionwiseFeedForward(nn.Module):
''' A two-feed-forward-layer module '''
def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
self.layer_norm = nn.LayerNorm(d_in)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
output = x.transpose(1, 2)
output = self.w_2(gelu(self.w_1(output)))
output = output.transpose(1, 2)
output = self.dropout(output)
output = self.layer_norm(output + residual)
return output
class DecoderAttention(nn.Module):
def __init__(self, config):
super(DecoderAttention, self).__init__()
self.att = MultiHeadAttention(config)
self.output = BertSelfOutput(config)
def forward(self, q, k, v, attention_mask):
att_output, attention_probs = self.att(q, k, v, attention_mask)
attention_output = self.output(att_output, q)
return attention_output, attention_probs
class DecoderLayer(nn.Module):
def __init__(self, config):
super(DecoderLayer, self).__init__()
self.slf_attn = DecoderAttention(config)
self.enc_attn = DecoderAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
intermediate_output = self.intermediate(dec_output)
dec_output = self.output(intermediate_output, dec_output)
return dec_output, dec_att_scores
class BertDecoderEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
super(BertDecoderEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
self.word_embeddings.weight = bert_word_embeddings_weight
self.position_embeddings.weight = bert_position_embeddings_weight
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertDecoder(nn.Module):
def __init__(self, config):
super(BertDecoder, self).__init__()
layer = DecoderLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])
def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
dec_att_scores = None
all_encoder_layers = []
all_dec_att_probs = []
for layer_module in self.layer:
hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
all_dec_att_probs.append(dec_att_scores)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
all_dec_att_probs.append(dec_att_scores)
return all_encoder_layers, all_dec_att_probs
class BertDecoderClassifier(nn.Module):
def __init__(self, config, embedding_weights):
super(BertDecoderClassifier, self).__init__()
self.cls = BertOnlyMLMHead(config, embedding_weights)
def forward(self, hidden_states):
cls_scores = self.cls(hidden_states)
return cls_scores
class DecoderBertModel(nn.Module):
def init_decoder_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, DecoderBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
final_norm (bool, optional): apply layer norm to the output of the
final decoder layer (default: True).
"""
def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
super(DecoderBertModel, self).__init__()
self.config = config
self.max_target_length = config.max_target_embeddings
self.embeddings = BertDecoderEmbeddings(config, bert_word_embeddings_weight, bert_position_embeddings_weight)
self.decoder = BertDecoder(config)
self.classifier = BertDecoderClassifier(config, bert_word_embeddings_weight)
self.apply(self.init_decoder_bert_weights)
def forward(self, input_ids, encoder_input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
"""
Args:
input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len, vocab)`
- the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)`
"""
embedding_output = self.embeddings(input_ids)
extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2) # b x 1 x 1 x ls
extended_encoder_mask = extended_encoder_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0
extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
extended_answer_mask = extended_answer_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
sz_b, len_s, _ = embedding_output.size()
subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1) # b x 1 x ls x ls
slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=next(self.parameters()).dtype)
self_attn_mask = slf_attn_mask * -10000.0
decoded_layers, dec_att_scores = self.decoder(embedding_output,
encoder_outs,
self_attn_mask,
extended_encoder_mask,
)
sequence_output = decoded_layers[-1]
cls_scores = self.classifier(sequence_output)
return cls_scores
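# A minimal sketch of the causal ("subsequent") mask built in `forward` above:
# an upper-triangular matrix whose 1s mark future positions; after the
# `* -10000.0` step those entries become large negative attention biases, so
# position i can only attend to positions <= i.  Length is hypothetical.
def _example_subsequent_mask(len_s=4):
    return torch.triu(torch.ones((len_s, len_s)), diagonal=1)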


@ -0,0 +1,239 @@
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps
from tqdm import tqdm
import boto3
from botocore.exceptions import ClientError
import requests
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
def url_to_filename(url: str, etag: str = None) -> str:
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
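# A minimal sketch: the cache file name is the sha256 hex digest of the URL,
# optionally suffixed with the digest of the ETag, so identical inputs always
# map to the same cache entry.  The URL and ETag below are hypothetical.
def _example_url_to_filename():
    name = url_to_filename("https://example.com/model.tar.gz", etag='"abc123"')
    assert name == url_to_filename("https://example.com/model.tar.gz", etag='"abc123"')
    return name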
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise FileNotFoundError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise FileNotFoundError("file {} not found".format(meta_path))
with open(meta_path) as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url: str) -> Tuple[str, str]:
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func: Callable):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url: str, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise FileNotFoundError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url: str) -> Optional[str]:
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url: str, temp_file: IO) -> None:
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url: str, temp_file: IO) -> None:
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def read_set_from_file(filename: str) -> Set[str]:
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r', encoding='utf-8') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path: str, dot=True, lower: bool = True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext


@ -0,0 +1,634 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from pytorch_pretrained_bert.bert_module import PreTrainedBertModel, BertModel, BertLayerNorm, BertOnlyMLMHead
from pytorch_pretrained_bert.visual_module import PreTrainedVisualBertModel, VisualBertModel, VisualBertLayerNorm, VisualBertOnlyMLMHead
from pytorch_pretrained_bert.cross_module import PreTrainedCrossBertModel, CrossBertModel, CrossBertLayerNorm
from pytorch_pretrained_bert.decoder_module import PreTrainedDecoderBertModel, DecoderBertModel, DecoderBertLayerNorm
logger = logging.getLogger(__name__)
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(LayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class PreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
self.bert_config = bert_config
self.visual_config = visual_config
self.cross_config = cross_config
self.decoder_config = decoder_config
self.bert = None
self.visual = None
self.cross = None
self.decoder = None
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
elif isinstance(module, BertLayerNorm) \
or isinstance(module, LayerNorm) \
or isinstance(module, VisualBertLayerNorm) \
or isinstance(module, CrossBertLayerNorm) \
or isinstance(module, DecoderBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
raise ValueError("`TODO: Just bert has this function, call it at here.`")
@classmethod
def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
if prefix is not None:
old_keys = []
new_keys = []
for key in state_dict.keys():
old_keys.append(key)
new_keys.append(prefix+key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='')
if prefix is None and (task_config is None or task_config.local_rank == 0):
logger.info("-"*20)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, pretrained_visual_model_name, pretrained_cross_model_name,
pretrained_decoder_model_name,
state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
task_config = None
if "task_config" in kwargs.keys():
task_config = kwargs["task_config"]
if not hasattr(task_config, "local_rank"):
task_config.__dict__["local_rank"] = 0
elif task_config.local_rank == -1:
task_config.local_rank = 0
bert_config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=task_config)
# the visual state_dict is not used here
visual_config, _ = PreTrainedVisualBertModel.get_config(pretrained_visual_model_name, cache_dir, state_dict=None, task_config=task_config)
# the cross and decoder state_dicts are not used here
cross_config, _ = PreTrainedCrossBertModel.get_config(pretrained_cross_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
decoder_config, _ = PreTrainedDecoderBertModel.get_config(pretrained_decoder_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
# Instantiate model.
model = cls(bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs)
assert model.bert is not None
assert model.visual is not None
# assert model.cross is not None
# assert model.decoder is not None
if state_dict is not None:
model = cls.init_preweight(model, state_dict, task_config=task_config)
return model
class CrossEn(nn.Module):
def __init__(self,):
super(CrossEn, self).__init__()
def forward(self, sim_matrix):
logpt = F.log_softmax(sim_matrix, dim=-1)
logpt = torch.diag(logpt)
nce_loss = -logpt
sim_loss = nce_loss.mean()
return sim_loss
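# A minimal sketch of `CrossEn` above: a square text-video similarity matrix
# is treated as logits, and the diagonal (matched pairs) provides the
# positives of a cross-entropy / InfoNCE-style loss.  Batch size is
# hypothetical.
def _example_cross_en():
    sim_matrix = torch.randn(4, 4)   # sim_matrix[i][j]: text i vs video j
    return CrossEn()(sim_matrix)     # mean of -log softmax(row) at the diagonal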
class MILNCELoss(nn.Module):
def __init__(self, batch_size=1, n_pair=1,):
super(MILNCELoss, self).__init__()
self.batch_size = batch_size
self.n_pair = n_pair
def forward(self, sim_matrix):
mm_mask = np.eye(self.batch_size)
mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)
from_text_matrix = sim_matrix + mm_mask * -1e12
from_video_matrix = sim_matrix.transpose(1, 0) # video * text
new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
logpt = F.log_softmax(new_sim_matrix, dim=-1)
mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
new_logpt = -torch.logsumexp(masked_logpt, dim=-1)
logpt_choice = torch.zeros_like(new_logpt)
mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
logpt_choice[mark_ind] = 1
sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=torch.uint8)).mean()
return sim_loss
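# A minimal sketch of the positive mask used by `MILNCELoss` above: `np.kron`
# expands the batch identity into n_pair x n_pair blocks, so every clip or
# caption drawn from the same video counts as a positive.  Sizes are
# hypothetical.
def _example_milnce_mask(batch_size=2, n_pair=2):
    mm_mask = np.kron(np.eye(batch_size), np.ones((n_pair, n_pair)))
    # [[1. 1. 0. 0.]
    #  [1. 1. 0. 0.]
    #  [0. 0. 1. 1.]
    #  [0. 0. 1. 1.]]
    return mm_mask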
class MaxMarginRankingLoss(nn.Module):
def __init__(self,
margin=1.0,
negative_weighting=False,
batch_size=1,
n_pair=1,
hard_negative_rate=0.5,
):
super(MaxMarginRankingLoss, self).__init__()
self.margin = margin
self.n_pair = n_pair
self.batch_size = batch_size
easy_negative_rate = 1 - hard_negative_rate
self.easy_negative_rate = easy_negative_rate
self.negative_weighting = negative_weighting
if n_pair > 1 and batch_size > 1:
alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
self.mm_mask = mm_mask.float()
def forward(self, x):
d = torch.diag(x)
max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
F.relu(self.margin + x - d.view(1, -1))
if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
max_margin = max_margin * self.mm_mask.to(max_margin.device)
return max_margin.mean()
class NormalizeVideo(nn.Module):
def __init__(self, task_config):
super(NormalizeVideo, self).__init__()
self.visual_norm2d = LayerNorm(task_config.video_dim)
def forward(self, video):
video = torch.as_tensor(video).float()
video = video.view(-1, video.shape[-2], video.shape[-1])
video = self.visual_norm2d(video)
return video
def show_log(task_config, info):
if task_config is None or task_config.local_rank == 0:
logger.warning(info)
def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
if hasattr(source_config, source_attr_name):
if default_value is None or getattr(source_config, source_attr_name) != default_value:
setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
show_log(source_config, "Set {}.{}: {}.".format(target_name,
target_attr_name, getattr(target_config, target_attr_name)))
return target_config
class VLBert(PreTrainedModel):
def __init__(self, bert_config, visual_config, cross_config, decoder_config, task_config):
super(VLBert, self).__init__(bert_config, visual_config, cross_config, decoder_config)
self.task_config = task_config
self.ignore_video_index = -1
assert self.task_config.max_words <= bert_config.max_position_embeddings
assert self.task_config.max_words <= decoder_config.max_target_embeddings
assert self.task_config.max_frames <= visual_config.max_position_embeddings
assert self.task_config.max_words + self.task_config.max_frames <= cross_config.max_position_embeddings
self._choice_sim = True
self._cross_model = False
if hasattr(self.task_config, 'cross_model') and self.task_config.cross_model:
self._choice_sim = False
self._cross_model = self.task_config.cross_model
self.train_sim_after_cross = False
if self._choice_sim and hasattr(self.task_config, 'train_sim_after_cross') and self.task_config.train_sim_after_cross:
self.train_sim_after_cross = True
show_log(task_config, "Test retrieval after cross encoder.")
self.with_sim_in_decoder = True
if hasattr(self.task_config, 'without_sim_in_decoder') and self.task_config.without_sim_in_decoder:
self.with_sim_in_decoder = False
show_log(task_config, "Set no sim after cross.")
self.pretrain_without_decoder = False
if hasattr(self.task_config, 'pretrain_without_decoder') and self.task_config.pretrain_without_decoder:
self.pretrain_without_decoder = True
show_log(task_config, "Set pretrain without decoder.")
show_log(task_config, "sim:{}, cross_model:{}".format(self._choice_sim, self._cross_model))
# The attribute name must be `bert` if a dynamic (resizable) vocabulary is needed
bert_config = update_attr("bert_config", bert_config, "num_hidden_layers",
self.task_config, "text_num_hidden_layers")
self.bert = BertModel(bert_config)
bert_word_embeddings_weight = self.bert.embeddings.word_embeddings.weight
bert_position_embeddings_weight = self.bert.embeddings.position_embeddings.weight
visual_config = update_attr("visual_config", visual_config, "num_hidden_layers",
self.task_config, "visual_num_hidden_layers")
self.visual = VisualBertModel(visual_config)
visual_word_embeddings_weight = self.visual.embeddings.word_embeddings.weight
if self._choice_sim is False or self.train_sim_after_cross:
self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
self.cls_visual = VisualBertOnlyMLMHead(visual_config, visual_word_embeddings_weight)
cross_config = update_attr("cross_config", cross_config, "num_hidden_layers",
self.task_config, "cross_num_hidden_layers")
self.cross = CrossBertModel(cross_config)
if self.train_sim_after_cross is False:
if self.task_config.do_pretrain and self.pretrain_without_decoder:
show_log(task_config, "Pretrain without decoder.")
else:
decoder_config = update_attr("decoder_config", decoder_config, "num_decoder_layers",
self.task_config, "decoder_num_hidden_layers")
self.decoder = DecoderBertModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
self.normalize_video = NormalizeVideo(task_config)
mILNCELoss = MILNCELoss(batch_size=task_config.batch_size // task_config.n_gpu, n_pair=task_config.n_pair, )
maxMarginRankingLoss = MaxMarginRankingLoss(margin=task_config.margin,
negative_weighting=task_config.negative_weighting,
batch_size=task_config.batch_size // task_config.n_gpu,
n_pair=task_config.n_pair,
hard_negative_rate=task_config.hard_negative_rate, )
if task_config.use_mil:
self.loss_fct = mILNCELoss
show_log(task_config, "Using MILNCE.")
else:
self.loss_fct = maxMarginRankingLoss
show_log(task_config, "Using Ranking Loss.")
if self._cross_model:
self.loss_fct = CrossEn()
show_log(task_config, "Use CE Loss.")
if self._choice_sim is False or self.train_sim_after_cross:
self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)
self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.pretrain_with_joint_sim = False
if hasattr(self.task_config, 'pretrain_with_joint_sim'):
if self.task_config.pretrain_with_joint_sim:
self.pretrain_with_joint_sim = True
show_log(task_config, "Set pretrain with joint embedding.")
if task_config.use_mil:
self._pretrain_sim_loss_fct = mILNCELoss
show_log(task_config, "Pretrain with joint embedding using MILNCE.")
else:
self._pretrain_sim_loss_fct = maxMarginRankingLoss
show_log(task_config, "Pretrain with joint embedding using Ranking Loss.")
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=None,
pairs_masked_text=None, pairs_token_labels=None, masked_video=None, video_labels_index=None,
input_caption_ids=None, decoder_mask=None, output_caption_ids=None):
input_ids = input_ids.view(-1, input_ids.shape[-1])
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
video = self.normalize_video(video)
if input_caption_ids is not None:
input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])
sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids,
attention_mask, video, video_mask, shaped=True)
if self.training:
loss = 0.
if self._choice_sim:
sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask, shaped=True)
sim_loss = self.loss_fct(sim_matrix)
loss = loss + sim_loss
if self._cross_model and pairs_masked_text is not None and pairs_token_labels is not None:
if self.task_config.do_pretrain:
pairs_masked_text = pairs_masked_text.view(-1, pairs_masked_text.shape[-1])
pairs_token_labels = pairs_token_labels.view(-1, pairs_token_labels.shape[-1])
masked_video = self.normalize_video(masked_video)
video_labels_index = video_labels_index.view(-1, video_labels_index.shape[-1])
sequence_output_alm, visual_output_alm = self.get_sequence_visual_output(pairs_masked_text, token_type_ids, attention_mask, masked_video, video_mask, shaped=True)
cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output_alm, visual_output_alm, attention_mask, video_mask)
sequence_cross_output, visual_cross_output = torch.split(cross_output, [attention_mask.size(-1), video_mask.size(-1)], dim=1)
alm_loss = self._calculate_mlm_loss(sequence_cross_output, pairs_token_labels) # For mlm
loss = loss + alm_loss
nce_loss = self._calculate_mfm_loss(visual_cross_output, video, video_mask, video_labels_index) # For mfm
loss += nce_loss
if self.pretrain_with_joint_sim:
sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask,
shaped=True, _pretrain_joint=True)
sim_loss_joint = self._pretrain_sim_loss_fct(sim_matrix)
loss = loss + sim_loss_joint
if (input_caption_ids is not None) and \
((self.task_config.do_pretrain and self.pretrain_without_decoder is False)
or (self.task_config.do_pretrain is False and self.task_config.task_type == "caption")):
if self.task_config.do_pretrain:
decoder_scores, res_tuples = self._get_decoder_score(sequence_output_alm, visual_output_alm,
input_ids, attention_mask, video_mask,
input_caption_ids, decoder_mask, shaped=True)
elif self.task_config.task_type == "caption":
decoder_scores, res_tuples = self._get_decoder_score(sequence_output, visual_output,
input_ids, attention_mask, video_mask,
input_caption_ids, decoder_mask, shaped=True)
output_caption_ids = output_caption_ids.view(-1, output_caption_ids.shape[-1])
decoder_loss = self.decoder_loss_fct(decoder_scores.view(-1, self.bert_config.vocab_size), output_caption_ids.view(-1))
loss = loss + decoder_loss
if (self.with_sim_in_decoder and self.task_config.do_pretrain) or \
self.task_config.task_type == "retrieval":
if self.task_config.do_pretrain:
sim_matrix_text_visual = self.get_similarity_logits(sequence_output_alm, visual_output_alm, attention_mask,
video_mask, shaped=True)
elif self.task_config.task_type == "retrieval":
sim_matrix_text_visual = self.get_similarity_logits(sequence_output, visual_output, attention_mask,
video_mask, shaped=True)
sim_loss_text_visual = self.loss_fct(sim_matrix_text_visual)
loss = loss + sim_loss_text_visual
return loss
else:
return None
def _calculate_mlm_loss(self, sequence_output_alm, pairs_token_labels):
alm_scores = self.cls(sequence_output_alm)
alm_loss = self.alm_loss_fct(alm_scores.view(-1, self.bert_config.vocab_size), pairs_token_labels.view(-1))
return alm_loss
def _calculate_mfm_loss(self, visual_output_alm, video, video_mask, video_labels_index):
afm_scores = self.cls_visual(visual_output_alm)
afm_scores_tr = afm_scores.view(-1, afm_scores.shape[-1])
video_tr = video.permute(2, 0, 1)
video_tr = video_tr.view(video_tr.shape[0], -1)
logits_matrix = torch.mm(afm_scores_tr, video_tr)
video_mask_float = video_mask.to(dtype=torch.float)
mask_matrix = torch.mm(video_mask_float.view(-1, 1), video_mask_float.view(1, -1))
masked_logits = logits_matrix + (1. - mask_matrix) * -1e8
logpt = F.log_softmax(masked_logits, dim=-1)
logpt = torch.diag(logpt)
nce_loss = -logpt
video_labels_index_mask = (video_labels_index != self.ignore_video_index)
nce_loss = nce_loss.masked_select(video_labels_index_mask.view(-1))
nce_loss = nce_loss.mean()
return nce_loss
# For inference
def get_sequence_visual_output(self, input_ids, token_type_ids, attention_mask, video, video_mask, shaped=False):
if shaped is False:
input_ids = input_ids.view(-1, input_ids.shape[-1])
token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
video = self.normalize_video(video)
encoded_layers, _ = self.bert("bi", input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
sequence_output = encoded_layers[-1]
visual_layers, _ = self.visual("bi", video, video_mask, output_all_encoded_layers=True)
visual_output = visual_layers[-1]
# cannot add any cross code here
# because sequence and visual features are exclusive
return sequence_output, visual_output
def _mean_pooling_for_similarity(self, sequence_output, visual_output, attention_mask, video_mask,):
attention_mask_un = attention_mask.to(dtype=torch.float).unsqueeze(-1)
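# zero out the first ([CLS]) position so it is excluded from the text mean pooling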
attention_mask_un[:, 0, :] = 0.
sequence_output = sequence_output * attention_mask_un
text_out = torch.sum(sequence_output, dim=1) / torch.sum(attention_mask_un, dim=1, dtype=torch.float)
video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
visual_output = visual_output * video_mask_un
video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
video_mask_un_sum[video_mask_un_sum == 0.] = 1.
video_out = torch.sum(visual_output, dim=1) / video_mask_un_sum
return text_out, video_out
def _cross_similarity(self, sequence_output, visual_output, attention_mask, video_mask):
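# Build a [num_texts x num_videos] similarity matrix by running every (text, video) pair
# through the cross encoder; texts are processed in chunks of `step_size` rows so the
# repeated/expanded tensors stay within memory.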
b_text, s_text, h_text = sequence_output.size()
b_visual, s_visual, h_visual = visual_output.size()
retrieve_logits_list = []
step_size = 5
split_size = [step_size] * (b_text // step_size)
release_size = b_text - sum(split_size)
if release_size > 0:
split_size += [release_size]
sequence_output_splits = torch.split(sequence_output, split_size, dim=0)
attention_mask_splits = torch.split(attention_mask, split_size, dim=0)
for i in range(len(split_size)):
sequence_output_row = sequence_output_splits[i]
attention_mask_row = attention_mask_splits[i]
sequence_output_l = sequence_output_row.unsqueeze(1).repeat(1, b_visual, 1, 1)
sequence_output_l = sequence_output_l.view(-1, s_text, h_text)
attention_mask_l = attention_mask_row.unsqueeze(1).repeat(1, b_visual, 1)
attention_mask_l = attention_mask_l.view(-1, s_text)
step_truth = sequence_output_row.size(0)
visual_output_r = visual_output.unsqueeze(0).repeat(step_truth, 1, 1, 1)
visual_output_r = visual_output_r.view(-1, s_visual, h_visual)
video_mask_r = video_mask.unsqueeze(0).repeat(step_truth, 1, 1)
video_mask_r = video_mask_r.view(-1, s_visual)
cross_output, pooled_output, concat_mask = \
self._get_cross_output(sequence_output_l, visual_output_r, attention_mask_l, video_mask_r)
retrieve_logits_row = self.similarity_dense(pooled_output).squeeze(-1).view(step_truth, b_visual)
retrieve_logits_list.append(retrieve_logits_row)
retrieve_logits = torch.cat(retrieve_logits_list, dim=0)
return retrieve_logits
# For inference
def get_similarity_logits(self, sequence_output, visual_output, attention_mask, video_mask, shaped=False, _pretrain_joint=False):
if shaped is False:
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
if self._cross_model and _pretrain_joint is False:
retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
else:
if self.train_sim_after_cross:
retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
else:
text_out, video_out = self._mean_pooling_for_similarity(sequence_output, visual_output, attention_mask, video_mask)
# Do a cosine similarity
if self.task_config.use_mil is False:
text_out = F.normalize(text_out, dim=-1)
video_out = F.normalize(video_out, dim=-1)
retrieve_logits = torch.matmul(text_out, video_out.t())
return retrieve_logits
def _get_cross_output(self, sequence_output, visual_output, attention_mask, video_mask):
# Generate one pair (x, y) ===>
concat_features = torch.cat((sequence_output, visual_output), dim=1)  # concatenate tokens and frames
concat_mask = torch.cat((attention_mask, video_mask), dim=1)
text_type_ = torch.zeros_like(attention_mask)
video_type_ = torch.ones_like(video_mask)
concat_type = torch.cat((text_type_, video_type_), dim=1)
# <===
cross_layers, pooled_output = self.cross("bi", concat_features, concat_type, concat_mask, output_all_encoded_layers=True)
cross_output = cross_layers[-1]
return cross_output, pooled_output, concat_mask
def _get_decoder_score(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=False):
if shaped is False:
input_ids = input_ids.view(-1, input_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])
res_tuples = ()
cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output, visual_output, attention_mask, video_mask)
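# the decoder attends over the cross-encoder output (text tokens + video frames) via
# encoder_outs/encoder_mask; decoder_mask marks the valid positions of the caption input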
decoder_scores = self.decoder(input_caption_ids, input_ids, encoder_outs=cross_output, answer_mask=decoder_mask, encoder_mask=concat_mask)
return decoder_scores, res_tuples
def decoder_caption(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask,
shaped=False, get_logits=False):
"""
:param get_logits: Used in Beam Search
:return:
"""
if shaped is False:
input_ids = input_ids.view(-1, input_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])
assert self.pretrain_without_decoder is False
decoder_scores, _ = self._get_decoder_score(sequence_output, visual_output,
input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=True)
if get_logits:
return decoder_scores
_, decoder_scores_result = torch.max(decoder_scores, -1)
return decoder_scores_result


@@ -0,0 +1,180 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
logger = logging.getLogger(__name__)
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
return 0.5 * (1.0 + math.cos(math.pi * x))  # x is a Python float here, so use math.cos rather than torch.cos
def warmup_constant(x, warmup=0.002):
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
Learning rate is 1. afterwards. """
if x < warmup:
return x/warmup
return 1.0
def warmup_linear(x, warmup=0.002):
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
After `t_total`-th training step, learning rate is zero. """
if x < warmup:
return x/warmup
return max((x-1.)/(warmup-1.), 0)
SCHEDULES = {
'warmup_cosine': warmup_cosine,
'warmup_constant': warmup_constant,
'warmup_linear': warmup_linear,
}
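# Worked example: with warmup=0.1, warmup_linear(0.05) == 0.5 (half-way up the ramp) and
# warmup_linear(0.55) == 0.5 (half-way back down toward 0 at x == 1.0).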
class BertAdam(Optimizer):
"""Implements BERT version of Adam algorithm with weight decay fix.
Params:
lr: learning rate
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
t_total: total number of training steps for the learning
rate schedule, -1 means constant learning rate. Default: -1
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
b1: Adams b1. Default: 0.9
b2: Adams b2. Default: 0.999
e: Adams epsilon. Default: 1e-6
weight_decay: Weight decay. Default: 0.01
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
"""
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
max_grad_norm=1.0):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
if not e >= 0.0:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults)
def get_lr(self):
lr = []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
state = self.state[p]
if len(state) == 0:
return [0]
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
lr.append(lr_scheduled)
return lr
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
warned_for_t_total = False
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['next_m'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['next_v'] = torch.zeros_like(p.data)
next_m, next_v = state['next_m'], state['next_v']
beta1, beta2 = group['b1'], group['b2']
# Add grad clipping
if group['max_grad_norm'] > 0:
clip_grad_norm_(p, group['max_grad_norm'])
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
next_m.mul_(beta1).add_(1 - beta1, grad)
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
update = next_m / (next_v.sqrt() + group['e'])
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
progress = state['step']/group['t_total']
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
# warning for exceeding t_total (only active with warmup_linear)
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
logger.warning(
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
"Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
warned_for_t_total = True
# end warning
else:
lr_scheduled = group['lr']
update_with_lr = lr_scheduled * update
p.data.add_(-update_with_lr)
state['step'] += 1
# step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
# No bias correction
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
return loss
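# A minimal usage sketch (hypothetical names and values), driving BertAdam like any other
# torch optimizer; the schedule is controlled by `warmup` and `t_total`:
#
#   optimizer = BertAdam(model.parameters(), lr=1e-4, warmup=0.1,
#                        t_total=num_train_steps, schedule='warmup_linear')
#   loss = model(batch)      # `model` and `batch` are placeholders
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()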


@@ -0,0 +1,408 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import unicodedata
import os
import sys
import logging
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-uncased': 512,
'bert-large-uncased': 512,
'bert-base-cased': 512,
'bert-large-cased': 512,
'bert-base-multilingual-uncased': 512,
'bert-base-multilingual-cased': 512,
'bert-base-chinese': 512,
}
VOCAB_NAME = 'vocab.txt'
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r", encoding="utf-8") as reader:
while True:
token = reader.readline()
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[MASK]", "[CLS]")):
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
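# e.g. tokenizer.tokenize("unaffable") -> ["un", "##aff", "##able"], assuming those word
# pieces exist in the loaded vocabulary (see WordpieceTokenizer below)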
def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
if token not in self.vocab:
ids.append(self.vocab["[UNK]"])
logger.error("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
else:
ids.append(self.vocab[token])
if len(ids) > self.max_len:
raise ValueError(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
)
return ids
def convert_ids_to_tokens(self, ids):
"""Converts a sequence of ids in tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
@classmethod
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
vocab_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(vocab_file) is False:
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
else:
vocab_file = pretrained_model_name
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
# redirect to the cache, if necessary
print(vocab_file)
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer won't index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
kwargs['never_split'] = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
return tokenizer
def add_tokens(self, new_tokens, model):
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from length of the current vocabulary.
Args:
new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
Returns:
Number of tokens added to the vocabulary.
Examples::
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
"""
to_add_tokens = []
for token in new_tokens:
assert isinstance(token, str)
to_add_tokens.append(token)
# logger.info("Adding %s to the vocabulary", token)
vocab = collections.OrderedDict()
for token in self.vocab.keys():
vocab[token] = self.vocab[token]
for token in to_add_tokens:
vocab[token] = len(vocab)
self.vocab = self.wordpiece_tokenizer.vocab = vocab
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
model.resize_token_embeddings(new_num_tokens=len(vocab))
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
self.never_split = never_split
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False


@@ -0,0 +1,12 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 1,
"vocab_size": 1024
}


@@ -0,0 +1,662 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'visual-bert-base': "[TODO]",
'visual-bert-large': "[TODO]",
}
CONFIG_NAME = 'visual_bert_config.json'
WEIGHTS_NAME = 'visual_pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class VisualBertConfig(object):
"""Configuration class to store the configuration of a `VisualBertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file=4096,
hidden_size=768,
num_hidden_layers=3,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
initializer_range=0.02):
"""Constructs VisualBertConfig.
Args:
vocab_size_or_config_json_file: `vocab_size` of the model (here, the dimensionality of the input visual features), or the path to a JSON config file.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `VisualBertConfig` from a Python dictionary of parameters."""
config = VisualBertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `VisualBertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as VisualBertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class VisualBertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(VisualBertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class VisualBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(VisualBertEmbeddings, self).__init__()
self.word_embeddings = nn.Linear(config.vocab_size, config.hidden_size)
# self.transform_act_fn = ACT2FN[config.hidden_act] \
# if isinstance(config.hidden_act, str) else config.hidden_act
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_embeddings):
seq_length = input_embeddings.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_embeddings.device)
position_ids = position_ids.unsqueeze(0).expand(input_embeddings.size(0), -1)
words_embeddings = self.word_embeddings(input_embeddings)
# words_embeddings = self.transform_act_fn(words_embeddings)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class VisualBertSelfAttention(nn.Module):
def __init__(self, config):
super(VisualBertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the VisualBertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class VisualBertSelfOutput(nn.Module):
def __init__(self, config):
super(VisualBertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class VisualBertAttention(nn.Module):
def __init__(self, config):
super(VisualBertAttention, self).__init__()
self.self = VisualBertSelfAttention(config)
self.output = VisualBertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class VisualBertIntermediate(nn.Module):
def __init__(self, config):
super(VisualBertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class VisualBertOutput(nn.Module):
def __init__(self, config):
super(VisualBertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class VisualBertLayer(nn.Module):
def __init__(self, config):
super(VisualBertLayer, self).__init__()
self.attention = VisualBertAttention(config)
self.intermediate = VisualBertIntermediate(config)
self.output = VisualBertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class VisualBertEncoder(nn.Module):
def __init__(self, config):
super(VisualBertEncoder, self).__init__()
layer = VisualBertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class VisualBertPooler(nn.Module):
def __init__(self, config):
super(VisualBertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class VisualBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(VisualBertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class VisualBertLMPredictionHead(nn.Module):
"""
Note: It is a trap for `visual_bert_model_embedding_weights`, which is different from Embedding of BERT
"""
def __init__(self, config, visual_bert_model_embedding_weights):
super(VisualBertLMPredictionHead, self).__init__()
self.transform = VisualBertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.weight = visual_bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(visual_bert_model_embedding_weights.size(1)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(self.weight) + self.bias
return hidden_states
class VisualBertOnlyMLMHead(nn.Module):
def __init__(self, config, visual_bert_model_embedding_weights):
super(VisualBertOnlyMLMHead, self).__init__()
self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class VisualBertOnlyNSPHead(nn.Module):
def __init__(self, config):
super(VisualBertOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class VisualBertPreTrainingHeads(nn.Module):
def __init__(self, config, visual_bert_model_embedding_weights):
super(VisualBertPreTrainingHeads, self).__init__()
self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedVisualBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedVisualBertModel, self).__init__()
if not isinstance(config, VisualBertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `VisualBertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_visual_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, VisualBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, state_dict, task_config=None):
'''
Factored out of `from_pretrained`: resolves the model config and (optionally) the state dict.
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = VisualBertConfig.from_json_file(config_file)
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
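# if the target model keeps the encoder under a `visual_bert` attribute, the saved keys
# already match; otherwise prepend the prefix so a bare VisualBertModel can load them too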
load(model, prefix='' if hasattr(model, 'visual_bert') else 'visual_bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedVisualBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `visual-bert-base`
. `visual-bert-large`
- a path or url to a pretrained model archive containing:
. `visual_bert_config.json` a configuration file for the model
. `visual_pytorch_model.bin` a PyTorch dump of a VisualBertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific VisualBert class
(ex: num_classes for VisualBertForSequenceClassification)
"""
config, state_dict = PreTrainedVisualBertModel.get_config(pretrained_model_name, cache_dir, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedVisualBertModel.init_preweight(model, state_dict)
return model
class VisualBertModel(PreTrainedVisualBertModel):
"""VisualBert model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a VisualBertConfig class instance with the configuration to build a new model
Inputs:
`type`: a str, indicates which masking will be used in the attention, choice from [`bi`, `seq`, `gen`]
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see Bert paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for VisualBert-base, 24 for VisualBert-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first token of the
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
Example usage:
```python
# `video` is a batch of frame features of shape [batch_size, num_frames, feature_dim]
video = torch.randn(2, 3, 4096)
video_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
config = modeling.VisualBertConfig(vocab_size_or_config_json_file=4096, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.VisualBertModel(config=config)
all_encoder_layers, pooled_output = model("bi", video, video_mask)
```
"""
def __init__(self, config):
super(VisualBertModel, self).__init__(config)
self.embeddings = VisualBertEmbeddings(config)
self.encoder = VisualBertEncoder(config)
self.pooler = VisualBertPooler(config)
self.apply(self.init_visual_bert_weights)
def forward(self, type, video, attention_mask=None, output_all_encoded_layers=True,
attention_sentenceA_mask=None, attention_sentenceB_mask=None):
if attention_mask is None:
attention_mask = torch.ones(video.size(0), video.size(1))
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(video)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output

train_retrieval_task.py Normal file

@@ -0,0 +1,529 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)
# from torch.utils.data import DataLoader
from data_prefetch_unitl import DataLoaderX as DataLoader # Enhanced Loader
import numpy as np
import random
import os
from youcook_nopair_dataloader import Youcook_NoPair_DataLoader
from metrics import compute_metrics, print_computed_metrics
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from util import parallel_apply
def get_logger(filename=None):
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
if filename is not None:
handler = logging.FileHandler(filename)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logging.getLogger().addHandler(handler)
return logger
global logger
def get_args(description='Youtube-Text-Video'):
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--max_words', type=int, default=20, help='')
parser.add_argument('--max_frames', type=int, default=100, help='')
parser.add_argument('--min_words', type=int, default=0, help='')
parser.add_argument('--feature_framerate', type=int, default=1, help='')
parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption.pickle', help='youcookii caption pickle file path')
parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii 2D feature path')
parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii 3D feature path')
parser.add_argument('--pad_token', type=str, default='[PAD]', help='')
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.");
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--n_gpu', type=int, default=1, help="Overwritten at runtime with the detected GPU count.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")
parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL, as in Miech et al. (2020).")
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use sampled MIL; takes priority over use_mil.")
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Number of hidden layers in the text encoder.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the visual encoder.")
parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Number of hidden layers in the cross encoder.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the decoder.")
parser.add_argument('--train_sim_after_cross', action='store_true', help="Test retrieval with the similarity computed after the cross encoder.")
args = parser.parse_args()
if args.sampled_use_mil:  # sampling from each video; takes priority over use_mil.
args.use_mil = True
# Check parameters
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if not args.do_pretrain and not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_pretrain` or `do_train` or `do_eval` must be True.")
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
return args
def set_seed_logger(args):
global logger
# predefining random initial seeds
random.seed(args.seed)
os.environ['PYTHONHASHSEED'] = str(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
logger.info("Effective parameters:")
for key in sorted(args.__dict__):
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
def init_device(args):
global logger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
args.n_gpu = n_gpu
if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))
return device, n_gpu
def init_model(args, device, n_gpu):
if args.init_model:
model_state_dict = torch.load(args.init_model, map_location='cpu')
else:
model_state_dict = None
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
if args.fp16:
try:
import apex
apex.amp.register_half_function(torch, 'einsum')
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
return model
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=1.0):
if hasattr(model, 'module'):
model = model.module
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]
no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n]
no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n]
decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n]
decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n]
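# Note on naming: despite the names, `no_decay_*` above holds the regular weight parameters (they receive
# weight decay 0.01 below), while `decay_*` holds bias/LayerNorm parameters (weight decay 0.0);
# parameters on the BERT branch additionally use a learning rate scaled by coef_lr.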
optimizer_grouped_parameters = [
{'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr},
{'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01},
{'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
{'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
]
scheduler = None
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
max_grad_norm=1.0)
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if n_gpu > 1:
model = torch.nn.DataParallel(model)
return optimizer, scheduler, model
def dataloader_youcook_train(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
args.n_pair = 1
youcook_dataset = Youcook_NoPair_DataLoader(
csv=args.youcook_train_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=args.n_pair,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
dataloader = DataLoader(
youcook_dataset,
batch_size=args.batch_size,
num_workers=args.num_thread_reader,
shuffle=True,
drop_last=True,
)
return dataloader, len(youcook_dataset)
def dataloader_youcook_test(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
youcook_testset = Youcook_NoPair_DataLoader(
csv=args.youcook_val_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=-1,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
test_sampler = SequentialSampler(youcook_testset)
dataloader_youcook = DataLoader(
youcook_testset,
sampler=test_sampler,
batch_size=args.batch_size_val,
num_workers=args.num_thread_reader,
pin_memory=False,
)
logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
return dataloader_youcook, len(youcook_testset)
def save_model(epoch, args, model, type_name=""):
# Only save the model itself
model_to_save = model.module if hasattr(model, 'module') else model
output_model_file = os.path.join(
args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch))
torch.save(model_to_save.state_dict(), output_model_file)
logger.info("Model saved to %s", output_model_file)
return output_model_file
def load_model(epoch, args, n_gpu, device, model_file=None):
if model_file is None or len(model_file) == 0:
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
if os.path.exists(model_file):
model_state_dict = torch.load(model_file, map_location='cpu')
logger.info("Model loaded from %s", model_file)
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
else:
model = None
return model
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step):
global logger
torch.cuda.empty_cache()
model.train()
log_step = 100
start_time = time.time()
total_loss = 0
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
for step, batch in enumerate(train_dataloader):
if n_gpu == 1:
# multi-gpu does scattering itself
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index = batch
loss = model(input_ids, segment_ids, input_mask, video, video_mask,
pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels,
masked_video=masked_video, video_labels_index=video_labels_index)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
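# When accumulating gradients, scale the loss so that the gradients summed over the accumulation steps
# match a single large-batch update.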
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
total_loss += float(loss)
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if scheduler is not None:
scheduler.step() # Update learning rate schedule
optimizer.step()
optimizer.zero_grad()
global_step += 1
if global_step % log_step == 0:
logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
args.epochs, step + 1,
len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]), loss,
(time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
start_time = time.time()
total_loss = total_loss / len(train_dataloader)
return total_loss, global_step
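# Compute one block of the text-to-video similarity matrix: every cached text (sequence) output in
# batch_list_t is scored against every cached video (visual) output in batch_list_v.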
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
sim_matrix = []
for idx1, b1 in enumerate(batch_list_t):
input_ids, input_mask, segment_ids, _, _, _, _, _, _ = b1
sequence_output = batch_sequence_output_list[idx1]
each_row = []
for idx2, b2 in enumerate(batch_list_v):
_, _, _, video, video_mask, _, _, _, _ = b2
visual_output = batch_visual_output_list[idx2]
b1b2_logits = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask)
b1b2_logits = b1b2_logits.cpu().detach().numpy()
each_row.append(b1b2_logits)
each_row = np.concatenate(tuple(each_row), axis=-1)
sim_matrix.append(each_row)
return sim_matrix
def eval_epoch(args, model, test_dataloader, device, n_gpu):
if hasattr(model, 'module'):
model = model.module.to(device)
else:
model = model.to(device)
model.eval()
with torch.no_grad():
batch_list = []
batch_sequence_output_list, batch_visual_output_list = [], []
for bid, batch in enumerate(test_dataloader):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, _, _, _, _ = batch
sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)
batch_sequence_output_list.append(sequence_output)
batch_visual_output_list.append(visual_output)
batch_list.append(batch)
print("{}/{}\r".format(bid, len(test_dataloader)), end="")
start_time = time.time()
if n_gpu > 1:
device_ids = list(range(n_gpu))
batch_list_t_splits = []
batch_list_v_splits = []
batch_t_output_splits = []
batch_v_output_splits = []
batch_len = len(batch_list)
split_len = (batch_len + n_gpu - 1) // n_gpu
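# Shard the text batches across GPUs; each shard is compared against the full set of video batches,
# so every GPU produces complete rows of the similarity matrix.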
for dev_id in device_ids:
s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
if dev_id == 0:
batch_list_t_splits.append(batch_list[s_:e_])
batch_list_v_splits.append(batch_list)
batch_t_output_splits.append(batch_sequence_output_list[s_:e_])
batch_v_output_splits.append(batch_visual_output_list)
else:
devc = torch.device('cuda:{}'.format(str(dev_id)))
devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list[s_:e_]]
batch_list_t_splits.append(devc_batch_list)
devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list]
batch_list_v_splits.append(devc_batch_list)
devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
batch_t_output_splits.append(devc_batch_list)
devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
batch_v_output_splits.append(devc_batch_list)
parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
sim_matrix = []
for idx in range(len(parallel_outputs)):
sim_matrix += parallel_outputs[idx]
else:
sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list)
logger.info("%f" % (time.time() - start_time))
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
metrics = compute_metrics(sim_matrix)
logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0])))
logger.info('\t>>> R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.
format(metrics['R1'], metrics['R5'], metrics['R10'], metrics['MR']))
R1 = metrics['R1']
return R1
DATALOADER_DICT = {}
DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader_youcook_test}
def main():
global logger
args = get_args()
set_seed_logger(args)
device, n_gpu = init_device(args)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
if args.do_train:
args.task_type = "retrieval"
model = init_model(args, device, n_gpu)
datatype = args.datatype
assert datatype in DATALOADER_DICT
test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
logger.info("***** Running test *****")
logger.info(" Num examples = %d", test_length)
logger.info(" Batch size = %d", args.batch_size_val)
logger.info(" Num steps = %d", len(test_dataloader))
if args.do_pretrain:
raise NotImplementedError("Deleted~ no use")
elif args.do_train:
train_dataloader, train_length = DATALOADER_DICT[datatype]["train"](args, tokenizer)
num_train_optimization_steps = int((len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
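# Use a smaller learning rate for the BERT branch unless an initial model is provided,
# in which case the whole network is trained at the full learning rate.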
coef_lr = 0.1
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=coef_lr)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
best_score = 0.00001
best_output_model_file = None
global_step = 0
for epoch in range(args.epochs):
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
scheduler, global_step)
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
output_model_file = save_model(epoch, args, model, type_name="")
R1 = eval_epoch(args, model, test_dataloader, device, n_gpu)
if best_score <= R1:
best_score = R1
best_output_model_file = output_model_file
logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))
model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
eval_epoch(args, model, test_dataloader, device, n_gpu)
elif args.do_eval:
eval_epoch(args, model, test_dataloader, device, n_gpu)
if __name__ == "__main__":
main()


@ -0,0 +1,850 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)
from data_prefetch_unitl import DataLoaderX as DataLoader # Enhanced Loader
import numpy as np
import random
import os
from collections import OrderedDict
from youcook_transcript_nopair_dataloader import Youcook_Transcript_NoPair_DataLoader
from youtube_transcript_dataloader import Youtube_Transcript_DataLoader
from rouge import Rouge
from nlgeval import compute_metrics, NLGEval
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from pytorch_pretrained_bert.beam import Beam
torch.distributed.init_process_group(backend="nccl")
global amp_pck_loaded_
try:
from apex import amp
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
amp.register_half_function(torch, 'einsum')
amp_pck_loaded_ = True
except ImportError:
amp_pck_loaded_ = False
def get_logger(filename=None):
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
if filename is not None:
handler = logging.FileHandler(filename)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logging.getLogger().addHandler(handler)
return logger
global logger
def get_args(description='Youtube-Text-Video'):
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--train_csv', type=str, default='data/HowTo100M_v1.csv', help='train csv')
parser.add_argument('--features_path_2D', type=str, default='feature_2d', help='feature path for 2D features')
parser.add_argument('--features_path_3D', type=str, default='feature_3d', help='feature path for 3D features')
parser.add_argument('--caption_path', type=str, default='data/caption.pickle', help='caption pickle file path')
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--max_words', type=int, default=20, help='')
parser.add_argument('--max_frames', type=int, default=100, help='')
parser.add_argument('--min_words', type=int, default=0, help='')
parser.add_argument('--feature_framerate', type=int, default=1, help='')
parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption_transcript.pickle', help='youcookii caption and transcription pickle file path')
parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii feature path for 2D features')
parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii feature path for 3D features')
parser.add_argument('--pad_token', type=str, default='[PAD]', help='')
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.");
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--n_gpu', type=int, default=1, help="Overwritten at runtime with the detected GPU count.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--task_type", default="caption", type=str, help="Point the task `retrieval` or `caption` to finetune.")
parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")
parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL, as in Miech et al. (2020).")
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use sampled MIL; takes priority over use_mil.")
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Number of hidden layers in the text encoder.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the visual encoder.")
parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Number of hidden layers in the cross encoder.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the decoder.")
parser.add_argument('--cross_model', action='store_true', help="Whether to train with the decoder.")
parser.add_argument('--pretrain_with_joint_sim', action='store_true', help="Whether to use the joint embedding when pretraining.")
parser.add_argument('--pretrain_enhance_vmodal', action='store_true', help="Enhance visual and other modalities when pretraining.")
parser.add_argument('--without_sim_in_decoder', action='store_true', help="Whether to align in the decoder when training.")
parser.add_argument('--pretrain_without_decoder', action='store_true', help="Whether to ignore the decoder when pretraining.")
parser.add_argument("--load_checkpoint", action="store_true")
parser.add_argument("--checkpoint_model", default="pytorch_model.bin.checkpoint", type=str, required=False,
help="Save the last model as a checkpoint.")
args = parser.parse_args()
if args.sampled_use_mil:  # sampling from each video; takes priority over use_mil.
args.use_mil = True
if args.do_pretrain is False:
args.pretrain_without_decoder = False
global amp_pck_loaded_
if args.fp16:
if not amp_pck_loaded_:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
# Check parameters
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if not args.do_pretrain and not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_pretrain` or `do_train` or `do_eval` must be True.")
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
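# Encode the BERT model name and the sequence limits into the checkpoint filename so that
# checkpoints from different configurations do not overwrite each other.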
args.checkpoint_model = '{}_{}_{}_{}.checkpoint'.format(args.checkpoint_model, args.bert_model, args.max_words, args.max_frames)
return args
def set_seed_logger(args):
global logger
# predefining random initial seeds
random.seed(args.seed)
os.environ['PYTHONHASHSEED'] = str(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# torch.multiprocessing.set_sharing_strategy('file_system') # RuntimeError: received 0 items of ancdata.
world_size = torch.distributed.get_world_size()
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
args.local_rank = local_rank
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
if local_rank == 0:
logger.info("Effective parameters:")
for key in sorted(args.__dict__):
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
return args, world_size, local_rank
def init_device(args, local_rank):
global logger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)
n_gpu = torch.cuda.device_count()
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
args.n_gpu = n_gpu
if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))
return device, n_gpu
def init_model(args, device, n_gpu, local_rank):
if args.init_model:
model_state_dict = torch.load(args.init_model, map_location='cpu')
else:
model_state_dict = None
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
return model
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
if hasattr(model, 'module'):
model = model.module
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]
no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n]
no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n]
decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n]
decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n]
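# As in the retrieval script: `no_decay_*` holds the regular weights (weight decay 0.01),
# `decay_*` holds bias/LayerNorm parameters (weight decay 0.0), and BERT-branch parameters
# use a learning rate scaled by coef_lr.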
optimizer_grouped_parameters = [
{'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr},
{'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01},
{'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
{'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
]
scheduler = None
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
max_grad_norm=1.0)
if args.fp16:
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if n_gpu > 1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
output_device=local_rank, find_unused_parameters=True)
return optimizer, scheduler, model
def dataloader_youcook_train(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
youcook_dataset = Youcook_Transcript_NoPair_DataLoader(
csv=args.youcook_train_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=-1,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
train_sampler = torch.utils.data.distributed.DistributedSampler(youcook_dataset)
dataloader = DataLoader(
youcook_dataset,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
pin_memory=False,
shuffle=(train_sampler is None),
sampler=train_sampler,
drop_last=True,
)
return dataloader, len(youcook_dataset), train_sampler
def dataloader_youcook_test(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
youcook_testset = Youcook_Transcript_NoPair_DataLoader(
csv=args.youcook_val_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=-1,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
test_sampler = SequentialSampler(youcook_testset)
dataloader_youcook = DataLoader(
youcook_testset,
sampler=test_sampler,
batch_size=args.batch_size_val,
num_workers=args.num_thread_reader,
pin_memory=False,
)
logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
return dataloader_youcook, len(youcook_testset)
def dataloader_pretrain(args, tokenizer, only_sim=False):
logger.info('Loading captions: {}'.format(args.caption_path))
caption = pickle.load(open(args.caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
dataset = Youtube_Transcript_DataLoader(
csv=args.train_csv,
features_path=args.features_path_2D,
features_path_3D=args.features_path_3D,
caption=caption,
min_time=args.min_time,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=args.n_pair,
pad_token=args.pad_token,
max_frames=args.max_frames,
use_mil=args.use_mil,
only_sim=only_sim,
sampled_use_mil=args.sampled_use_mil,
pretrain_enhance_vmodal=args.pretrain_enhance_vmodal,
video_dim=args.video_dim,
)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
pin_memory=False,
shuffle=(sampler is None),
sampler=sampler,
drop_last=True,
)
return dataloader, len(dataset), sampler
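# Recursively cast every tensor in a (possibly nested) state dict to the given CPU tensor type,
# so that the optimizer state can be written into the training checkpoint.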
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
def save_model(epoch, args, model, local_rank, type_name="", global_step=-1, optimizer=None):
# Only save the model itself
model_to_save = model.module if hasattr(model, 'module') else model
output_model_file = os.path.join(
args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch))
torch.save(model_to_save.state_dict(), output_model_file)
logger.info("Model saved to %s", output_model_file)
if global_step != -1 and optimizer is not None:
amp_state_dict = {}
if args.fp16:
amp_state_dict = amp.state_dict()
state_dict = {
'epoch': epoch,
'global_step': global_step,
'model_state_dict': model_to_save.state_dict(),
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'amp_state_dict': amp_state_dict,
}
checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
torch.save(state_dict, checkpoint_model_file)
logger.info("Checkpoint is saved. use `load_checkpoint` to recovery it.")
return output_model_file
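# Resume either from the rolling training checkpoint (when epoch == -1 and --load_checkpoint is set)
# or from a per-epoch model file; the checkpoint also restores global_step, the optimizer state, and
# the AMP state.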
def load_model(epoch, args, n_gpu, device, model, global_step=0, model_file=None):
if model_file is None or len(model_file) == 0:
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
last_optim_state = None
amp_state_dict = None
checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
if epoch == -1 and args.load_checkpoint and os.path.exists(checkpoint_model_file):
checkpoint_state = torch.load(checkpoint_model_file, map_location='cpu')
epoch = checkpoint_state['epoch']
global_step = checkpoint_state['global_step']
model_state_dict = checkpoint_state['model_state_dict']
last_optim_state = checkpoint_state['last_optimizer_state']
if args.fp16 and 'amp_state_dict' in checkpoint_state:
amp_state_dict = checkpoint_state['amp_state_dict']
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
logger.info("Checkpoint loaded from %s", checkpoint_model_file)
elif os.path.exists(model_file):
model_state_dict = torch.load(model_file, map_location='cpu')
logger.info("Model loaded from %s", model_file)
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
return epoch, global_step, last_optim_state, amp_state_dict, model
def train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer, scheduler,
global_step, nlgEvalObj=None, local_rank=0):
global logger
torch.cuda.empty_cache()
model.train()
log_step = args.n_display
start_time = time.time()
total_loss = 0
for step, batch in enumerate(train_dataloader):
# if n_gpu == 1:
# # multi-gpu does scattering it-self
# batch = tuple(t.to(device) for t in batch)
batch = tuple(t.to(device=device, non_blocking=True) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index,\
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch
loss = model(input_ids, segment_ids, input_mask, video, video_mask,
pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels,
masked_video=masked_video, video_labels_index=video_labels_index,
input_caption_ids=pairs_input_caption_ids, decoder_mask=pairs_decoder_mask,
output_caption_ids=pairs_output_caption_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
total_loss += float(loss)
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if scheduler is not None:
scheduler.step() # Update learning rate schedule
optimizer.step()
optimizer.zero_grad()
global_step += 1
if global_step % log_step == 0 and local_rank == 0:
logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
args.epochs, step + 1,
len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
float(loss) * args.gradient_accumulation_steps,
(time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
start_time = time.time()
total_loss = total_loss / len(train_dataloader)
return total_loss, global_step
# ---------------------------------------->
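# Beam-search helpers for caption generation: map active instances to tensor positions, prune finished
# instances from the repeated tensors, run one decoding step per beam, and collect the n-best hypotheses.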
def get_inst_idx_to_tensor_position_map(inst_idx_list):
''' Indicate the position of an instance in a tensor. '''
return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
''' Collect tensor parts associated to active instances. '''
_, *d_hs = beamed_tensor.size()
n_curr_active_inst = len(curr_active_inst_idx)
new_shape = (n_curr_active_inst * n_bm, *d_hs)
beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
beamed_tensor = beamed_tensor.view(*new_shape)
return beamed_tensor
def collate_active_info(input_tuples, inst_idx_to_position_map, active_inst_idx_list, n_bm, device):
assert isinstance(input_tuples, tuple)
sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt = input_tuples
# Sentences which are still active are collected,
# so the decoder will not run on completed sentences.
n_prev_active_inst = len(inst_idx_to_position_map)
active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
active_inst_idx = torch.LongTensor(active_inst_idx).to(device)
active_sequence_output_rpt = collect_active_part(sequence_output_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_visual_output_rpt = collect_active_part(visual_output_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_input_ids_rpt = collect_active_part(input_ids_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_input_mask_rpt = collect_active_part(input_mask_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_video_mask_rpt = collect_active_part(video_mask_rpt, active_inst_idx, n_prev_active_inst, n_bm)
active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
return (active_sequence_output_rpt, active_visual_output_rpt, active_input_ids_rpt, active_input_mask_rpt, active_video_mask_rpt), \
active_inst_idx_to_position_map
def beam_decode_step(decoder, inst_dec_beams, len_dec_seq,
inst_idx_to_position_map, n_bm, device, input_tuples, decoder_length=None):
''' Decode and update beam status, and then return active beam idx. '''
assert isinstance(input_tuples, tuple)
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
dec_partial_seq = torch.stack(dec_partial_seq).to(device)
dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
return dec_partial_seq
def predict_word(next_decoder_ids, n_active_inst, n_bm, device, input_tuples):
sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt = input_tuples
next_decoder_mask = torch.ones(next_decoder_ids.size(), dtype=torch.uint8).to(device)
dec_output = decoder(sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt,
video_mask_rpt, next_decoder_ids, next_decoder_mask, shaped=True, get_logits=True)
dec_output = dec_output[:, -1, :]
word_prob = torch.nn.functional.log_softmax(dec_output, dim=1)
word_prob = word_prob.view(n_active_inst, n_bm, -1)
return word_prob
def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map, decoder_length=None):
active_inst_idx_list = []
for inst_idx, inst_position in inst_idx_to_position_map.items():
if decoder_length is None:
is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
else:
is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position], word_length=decoder_length[inst_idx])
if not is_inst_complete:
active_inst_idx_list += [inst_idx]
return active_inst_idx_list
n_active_inst = len(inst_idx_to_position_map)
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
word_prob = predict_word(dec_seq, n_active_inst, n_bm, device, input_tuples)
# Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list = collect_active_inst_idx_list(inst_dec_beams, word_prob, inst_idx_to_position_map,
decoder_length=decoder_length)
return active_inst_idx_list
def collect_hypothesis_and_scores(inst_dec_beams, n_best):
all_hyp, all_scores = [], []
for inst_idx in range(len(inst_dec_beams)):
scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
all_scores += [scores[:n_best]]
hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
all_hyp += [hyps]
return all_hyp, all_scores
# >----------------------------------------
def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=None, nlgEvalObj=None, test_set=None):
if hasattr(model, 'module'):
model = model.module.to(device)
if model._choice_sim:
return 0.
all_result_lists = []
all_caption_lists = []
model.eval()
for batch in test_dataloader:
batch = tuple(t.to(device, non_blocking=True) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch
with torch.no_grad():
sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)
# -- Repeat data for beam search
n_bm = 5 # beam_size
device = sequence_output.device
n_inst, len_s, d_h = sequence_output.size()
_, len_v, v_h = visual_output.size()
decoder = model.decoder_caption # This is a decoder function
# Note: inputs are shaped first, so the decoder needs the parameter shaped=True
input_ids = input_ids.view(-1, input_ids.shape[-1]) # [batch_size*n_pair, self.max_words]
input_mask = input_mask.view(-1, input_mask.shape[-1]) # [batch_size*n_pair, self.max_words]
video_mask = video_mask.view(-1, video_mask.shape[-1]) # [batch_size*n_pair, self.max_frames]
sequence_output_rpt = sequence_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
visual_output_rpt = visual_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_v, v_h)
input_ids_rpt = input_ids.repeat(1, n_bm).view(n_inst * n_bm, len_s)
input_mask_rpt = input_mask.repeat(1, n_bm).view(n_inst * n_bm, len_s)
video_mask_rpt = video_mask.repeat(1, n_bm).view(n_inst * n_bm, len_v)
# -- Prepare beams
inst_dec_beams = [Beam(n_bm, device=device, tokenizer=tokenizer) for _ in range(n_inst)]
# -- Bookkeeping for active or not
active_inst_idx_list = list(range(n_inst))
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
# -- Decode
for len_dec_seq in range(1, args.max_words + 1):
active_inst_idx_list = beam_decode_step(decoder, inst_dec_beams,
len_dec_seq, inst_idx_to_position_map, n_bm, device,
(sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt))
if not active_inst_idx_list:
break # all instances have finished their path to <EOS>
(sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt), \
inst_idx_to_position_map = collate_active_info((sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt),
inst_idx_to_position_map, active_inst_idx_list, n_bm, device)
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
result_list = [batch_hyp[i][0] for i in range(n_inst)]
pairs_output_caption_ids = pairs_output_caption_ids.view(-1, pairs_output_caption_ids.shape[-1])
caption_list = pairs_output_caption_ids.cpu().detach().numpy()
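# Convert token ids back to text (both hypotheses and references): truncate at the first [SEP]/[PAD]
# and merge WordPiece sub-tokens by stripping the '##' prefixes.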
for re_idx, re_list in enumerate(result_list):
decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
if "[SEP]" in decode_text_list:
SEP_index = decode_text_list.index("[SEP]")
decode_text_list = decode_text_list[:SEP_index]
if "[PAD]" in decode_text_list:
PAD_index = decode_text_list.index("[PAD]")
decode_text_list = decode_text_list[:PAD_index]
decode_text = ' '.join(decode_text_list)
decode_text = decode_text.replace(" ##", "").strip("##").strip()
all_result_lists.append(decode_text)
for re_idx, re_list in enumerate(caption_list):
decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
if "[SEP]" in decode_text_list:
SEP_index = decode_text_list.index("[SEP]")
decode_text_list = decode_text_list[:SEP_index]
if "[PAD]" in decode_text_list:
PAD_index = decode_text_list.index("[PAD]")
decode_text_list = decode_text_list[:PAD_index]
decode_text = ' '.join(decode_text_list)
decode_text = decode_text.replace(" ##", "").strip("##").strip()
all_caption_lists.append(decode_text)
# Save full results
if test_set is not None and hasattr(test_set, 'iter2video_pairs_dict'):
hyp_path = os.path.join(args.output_dir, "hyp_complete_results.txt")
with open(hyp_path, "w", encoding='utf-8') as writer:
writer.write("{}\t{}\t{}\n".format("video_id", "start_time", "caption"))
for idx, pre_txt in enumerate(all_result_lists):
video_id, sub_id = test_set.iter2video_pairs_dict[idx]
start_time = test_set.caption[video_id]['start'][sub_id]
writer.write("{}\t{}\t{}\n".format(video_id, start_time, pre_txt))
logger.info("File of complete results is saved in {}".format(hyp_path))
# Save pure results
hyp_path = os.path.join(args.output_dir, "hyp.txt")
with open(hyp_path, "w", encoding='utf-8') as writer:
for pre_txt in all_result_lists:
writer.write(pre_txt+"\n")
ref_path = os.path.join(args.output_dir, "ref.txt")
with open(ref_path, "w", encoding='utf-8') as writer:
for ground_txt in all_caption_lists:
writer.write(ground_txt + "\n")
# Filter out hyps of 0 length
hyps_and_refs = zip(all_result_lists, all_caption_lists)
hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0
and len([" ".join(sub_s.split()) for sub_s in _[0].split(".") if len(sub_s) > 0]) > 0]
all_result_lists, all_caption_lists = zip(*hyps_and_refs)
# Evaluate
metrics_dict = rougeObj.get_scores(hyps=all_result_lists, refs=all_caption_lists, avg=True, ignore_empty=True)
logger.info(">>> rouge_1f: {:.4f}, rouge_2f: {:.4f}, rouge_lf: {:.4f}".
format(metrics_dict["rouge-1"]["f"], metrics_dict["rouge-2"]["f"], metrics_dict["rouge-l"]["f"]))
metrics_nlg = nlgEvalObj.compute_metrics(ref_list=[all_caption_lists], hyp_list=all_result_lists)
logger.info(">>> BLEU_1: {:.4f}, BLEU_2: {:.4f}, BLEU_3: {:.4f}, BLEU_4: {:.4f}".
format(metrics_nlg["Bleu_1"], metrics_nlg["Bleu_2"], metrics_nlg["Bleu_3"], metrics_nlg["Bleu_4"]))
logger.info(">>> METEOR: {:.4f}, ROUGE_L: {:.4f}, CIDEr: {:.4f}".format(metrics_nlg["METEOR"], metrics_nlg["ROUGE_L"], metrics_nlg["CIDEr"]))
Bleu_4 = metrics_nlg["Bleu_4"]
return Bleu_4
DATALOADER_DICT = {}
DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader_youcook_test}
def main():
global logger
args = get_args()
args, world_size, local_rank = set_seed_logger(args)
device, n_gpu = init_device(args, local_rank)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
if args.do_train:
args.task_type = "caption"
model = init_model(args, device, n_gpu, local_rank)
only_sim = model.module._choice_sim if hasattr(model, 'module') else model._choice_sim
if args.task_type == "caption":
rougeObj = Rouge()
nlgEvalObj = NLGEval(no_overlap=False, no_skipthoughts=True, no_glove=True, metrics_to_omit=None)
datatype = "youcook"
assert datatype in DATALOADER_DICT
if args.do_pretrain is False:
test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
if local_rank == 0:
logger.info("***** Running test *****")
logger.info(" Num examples = %d", test_length)
logger.info(" Batch size = %d", args.batch_size_val)
logger.info(" Num steps = %d", len(test_dataloader))
if args.do_pretrain:
train_dataloader, train_length, sampler = dataloader_pretrain(args, tokenizer, only_sim=only_sim)
num_train_optimization_steps = int((len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
global_step = 0
epoch = -1
last_optim_state = None
amp_state_dict = None
if args.load_checkpoint:
epoch, global_step, last_optim_state, amp_state_dict, model = load_model(epoch, args, n_gpu, device, model, global_step=global_step)
epoch += 1
if local_rank == 0:
logger.warning("Will continue to epoch: {}".format(epoch))
epoch = 0 if epoch < 0 else epoch
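# Fine-tuning from an init_model uses the full learning rate on the BERT branch;
# otherwise the branch is damped by --coef_lr.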
coef_lr = args.coef_lr
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)
if last_optim_state is not None:
optimizer.load_state_dict(last_optim_state)
if amp_state_dict is not None:
amp.load_state_dict(amp_state_dict)
if local_rank == 0:
logger.info("***** Running pretraining *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)
iter_ls_ = [itm for itm in range(args.epochs) if itm >= epoch]
for epoch in iter_ls_:
sampler.set_epoch(epoch)
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)
if local_rank == 0:
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
save_model(epoch, args, model, local_rank, type_name="pretrain", global_step=global_step, optimizer=optimizer)
elif args.do_train:
train_dataloader, train_length, train_sampler = DATALOADER_DICT[datatype]["train"](args, tokenizer)
num_train_optimization_steps = int((len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
coef_lr = args.coef_lr
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)
if local_rank == 0:
logger.info("***** Running training *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)
best_score = 0.00001
best_output_model_file = None
global_step = 0
for epoch in range(args.epochs):
train_sampler.set_epoch(epoch)
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)
if local_rank == 0:
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
output_model_file = save_model(epoch, args, model, local_rank, type_name="")
if epoch > 0:
Bleu_4 = eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
if best_score <= Bleu_4:
best_score = Bleu_4
best_output_model_file = output_model_file
logger.info("The best model is: {}, the Bleu_4 is: {:.4f}".format(best_output_model_file, best_score))
else:
logger.warning("Skip the evaluation after {}-th epoch.".format(epoch+1))
if local_rank == 0:
_, _, _, _, model = load_model(-1, args, n_gpu, device, model, model_file=best_output_model_file)
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
elif args.do_eval:
if local_rank == 0:
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
if __name__ == "__main__":
main()

59
util.py Normal file

@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import threading
from torch._utils import ExceptionWrapper
def get_a_var(obj):
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, list) or isinstance(obj, tuple):
for result in map(get_a_var, obj):
if isinstance(result, torch.Tensor):
return result
if isinstance(obj, dict):
for result in map(get_a_var, obj.items()):
if isinstance(result, torch.Tensor):
return result
return None
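# Minimal parallel_apply for arbitrary callables: replicate `model` onto the given devices, run
# `fct(replica, *inputs)` in one thread per device, and re-raise any worker exception in the main thread.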
def parallel_apply(fct, model, inputs, device_ids):
modules = nn.parallel.replicate(model, device_ids)
assert len(modules) == len(inputs)
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input):
torch.set_grad_enabled(grad_enabled)
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
output = fct(module, *input)
with lock:
results[i] = output
except Exception:
with lock:
results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))
if len(modules) > 1:
threads = [threading.Thread(target=_worker, args=(i, module, input))
for i, (module, input) in enumerate(zip(modules, inputs))]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0])
outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, ExceptionWrapper):
output.reraise()
outputs.append(output)
return outputs


@ -0,0 +1,210 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import random
class Youcook_NoPair_DataLoader(Dataset):
"""Youcook dataset loader."""
def __init__(
self,
csv,
features_path,
caption,
tokenizer,
min_time=10.0,
features_path_3D=None,
feature_framerate=1.0,
feature_framerate_3D=24.0 / 16.0,
max_words=30,
min_words=0,
n_pair=-1, # -1 for test
pad_token="[PAD]",
max_frames=100,
):
"""
Args:
"""
self.csv = pd.read_csv(csv)
self.features_path_2D = features_path
self.features_path_3D = features_path_3D
self.caption = caption
self.min_time = min_time
self.feature_framerate = feature_framerate
self.feature_framerate_3D = feature_framerate_3D
self.max_words = max_words
self.max_frames = max_frames
self.min_words = min_words
self.tokenizer = tokenizer
self.n_pair = n_pair
self.pad_token = pad_token
self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
self.feature_path = {'2d': features_path}
if features_path_3D != '':
self.feature_path['3d'] = features_path_3D
# Get iterator video ids
video_id_list = [itm for itm in self.csv['video_id'].values]
self.video_id2idx_dict = {video_id: id for id, video_id in enumerate(video_id_list)}
# Get all captions
self.iter2video_pairs_dict = {}
iter_idx_ = 0
for video_id in video_id_list:
caption = self.caption[video_id]
n_caption = len(caption['start'])
for sub_id in range(n_caption):
self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
iter_idx_ += 1
def __len__(self):
return len(self.iter2video_pairs_dict)
def _get_text(self, video_id, sub_id):
caption = self.caption[video_id]
k = 1
r_ind = [sub_id]
starts = np.zeros(k)
ends = np.zeros(k)
pairs_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)
for i in range(k):
ind = r_ind[i]
# Note: n_pair_max=-1 means eval
words = self.tokenizer.tokenize(caption['text'][ind])
start_, end_ = caption['start'][ind], caption['end'][ind]
starts[i], ends[i] = start_, end_
words = ["[CLS]"] + words
total_length_with_CLS = self.max_words - 1
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + ["[SEP]"]
# Mask Language Model <-----
token_labels = []
masked_tokens = words.copy()
for token_id, token in enumerate(masked_tokens):
if token_id == 0 or token_id == len(masked_tokens) - 1:
token_labels.append(-1)
continue
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
masked_tokens[token_id] = "[MASK]"
# 10% randomly change token to random token
elif prob < 0.9:
masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
# -> rest 10% randomly keep current token
# append current token to output (we will predict these later)
try:
token_labels.append(self.tokenizer.vocab[token])
except KeyError:
# For unknown words (should not occur with BPE vocab)
token_labels.append(self.tokenizer.vocab["[UNK]"])
# print("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
else:
# no masking token (will be ignored by loss function later)
token_labels.append(-1)
# -----> Mask Language Model
input_ids = self.tokenizer.convert_tokens_to_ids(words)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
while len(input_ids) < self.max_words:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
masked_token_ids.append(0)
token_labels.append(-1)
assert len(input_ids) == self.max_words
assert len(input_mask) == self.max_words
assert len(segment_ids) == self.max_words
assert len(masked_token_ids) == self.max_words
assert len(token_labels) == self.max_words
pairs_text[i] = np.array(input_ids)
pairs_mask[i] = np.array(input_mask)
pairs_segment[i] = np.array(segment_ids)
pairs_masked_text[i] = np.array(masked_token_ids)
pairs_token_labels[i] = np.array(token_labels)
return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, starts, ends
def _get_video(self, idx, s, e):
video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
max_video_length = [0] * len(s)
f_title = "feature_file_2D"
fps_k = "2d"
feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
video_features = np.load(feature_file)
video = np.zeros((len(s), self.max_frames, video_features.shape[-1]), dtype=np.float)
for i in range(len(s)):
start = int(s[i] * self.fps[fps_k])
end = int(e[i] * self.fps[fps_k]) + 1
video_slice = video_features[start:end]
if self.max_frames < video_slice.shape[0]:
video_slice = video_slice[:self.max_frames]
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
print("video_id: {}, start: {}, end: {}".format(feature_file, start, end))
else:
video[i][:slice_shape[0]] = video_slice
for i, v_length in enumerate(max_video_length):
video_mask[i][:v_length] = [1] * v_length
# Mask Frame Model <-----
video_labels_index = [[] for _ in range(len(s))]
masked_video = video.copy()
for i, video_pair_ in enumerate(masked_video):
for j, _ in enumerate(video_pair_):
if j < max_video_length[i]:
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
masked_video[i][j] = [0.] * video.shape[-1]
video_labels_index[i].append(j)
else:
video_labels_index[i].append(-1)
else:
video_labels_index[i].append(-1)
video_labels_index = np.array(video_labels_index, dtype=np.long)
# -----> Mask Frame Model
return video, video_mask, masked_video, video_labels_index
def __getitem__(self, feature_idx):
video_id, sub_id = self.iter2video_pairs_dict[feature_idx]
idx = self.video_id2idx_dict[video_id]
pairs_text, pairs_mask, pairs_segment, \
pairs_masked_text, pairs_token_labels, starts, ends = self._get_text(video_id, sub_id)
video, video_mask, masked_video, video_labels_index = self._get_video(idx, starts, ends)
return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index
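A hedged sketch of wiring `Youcook_NoPair_DataLoader` into a standard PyTorch `DataLoader`; the paths, tokenizer setup, and batch size below are placeholders rather than values taken from this commit:

```python
# Illustrative wiring only; every path and hyperparameter here is a placeholder.
import pickle
from torch.utils.data import DataLoader
from pytorch_pretrained_bert.tokenization import BertTokenizer  # assumed import path of the bundled tokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
with open("path/to/youcook_caption.pickle", "rb") as f:
    caption = pickle.load(f)  # dict: video_id -> {'start', 'end', 'text'}

dataset = Youcook_NoPair_DataLoader(
    csv="path/to/youcook_val.csv",
    features_path="path/to/2d_features",
    caption=caption,
    tokenizer=tokenizer,
    max_words=48,
    max_frames=48,
)
val_loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)
```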


@ -0,0 +1,246 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import re
import random
import io
class Youcook_Transcript_NoPair_DataLoader(Dataset):
"""Youcook dataset loader."""
def __init__(
self,
csv,
features_path,
caption,
tokenizer,
min_time=10.0,
features_path_3D=None,
feature_framerate=1.0,
feature_framerate_3D=24.0 / 16.0,
max_words=30,
min_words=0,
n_pair=-1, # -1 for test
pad_token="[PAD]",
max_frames=100,
):
"""
Args:
"""
self.csv = pd.read_csv(csv)
self.features_path_2D = features_path
self.features_path_3D = features_path_3D
self.caption = caption
self.min_time = min_time
self.feature_framerate = feature_framerate
self.feature_framerate_3D = feature_framerate_3D
self.max_words = max_words
self.max_frames = max_frames
self.min_words = min_words
self.tokenizer = tokenizer
self.n_pair = n_pair
self.pad_token = pad_token
self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
self.feature_path = {'2d': features_path}
if features_path_3D != '':
self.feature_path['3d'] = features_path_3D
_feature_file = os.path.join(features_path, self.csv["feature_file_2D"].values[0])
self.feature_size = np.load(_feature_file).shape[-1]
# Get iterator video ids
video_id_list = [itm for itm in self.csv['video_id'].values]
self.video_id2idx_dict = {video_id: id for id, video_id in enumerate(video_id_list)}
# Get all captions
self.iter2video_pairs_dict = {}
iter_idx_ = 0
for video_id in video_id_list:
caption = self.caption[video_id]
n_caption = len(caption['start'])
for sub_id in range(n_caption):
self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
iter_idx_ += 1
def __len__(self):
return len(self.iter2video_pairs_dict)
def _get_text(self, video_id, sub_id):
caption = self.caption[video_id]
k = 1
r_ind = [sub_id]
starts = np.zeros(k)
ends = np.zeros(k)
pairs_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)
pairs_input_caption_ids = np.zeros((k, self.max_words), dtype=np.long)
pairs_output_caption_ids = np.zeros((k, self.max_words), dtype=np.long)
pairs_decoder_mask = np.zeros((k, self.max_words), dtype=np.long)
for i in range(k):
ind = r_ind[i]
start_, end_ = caption['start'][ind], caption['end'][ind]
starts[i], ends[i] = start_, end_
total_length_with_CLS = self.max_words - 1
words = self.tokenizer.tokenize(caption['transcript'][ind])
words = ["[CLS]"] + words
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + ["[SEP]"]
# Mask Language Model <-----
token_labels = []
masked_tokens = words.copy()
for token_id, token in enumerate(masked_tokens):
if token_id == 0 or token_id == len(masked_tokens) - 1:
token_labels.append(-1)
continue
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
masked_tokens[token_id] = "[MASK]"
# 10% randomly change token to random token
elif prob < 0.9:
masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
# -> rest 10% randomly keep current token
# append current token to output (we will predict these later)
try:
token_labels.append(self.tokenizer.vocab[token])
except KeyError:
# For unknown words (should not occur with BPE vocab)
token_labels.append(self.tokenizer.vocab["[UNK]"])
# print("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
else:
# no masking token (will be ignored by loss function later)
token_labels.append(-1)
# -----> Mask Language Model
input_ids = self.tokenizer.convert_tokens_to_ids(words)
masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
while len(input_ids) < self.max_words:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
masked_token_ids.append(0)
token_labels.append(-1)
assert len(input_ids) == self.max_words
assert len(input_mask) == self.max_words
assert len(segment_ids) == self.max_words
assert len(masked_token_ids) == self.max_words
assert len(token_labels) == self.max_words
pairs_text[i] = np.array(input_ids)
pairs_mask[i] = np.array(input_mask)
pairs_segment[i] = np.array(segment_ids)
pairs_masked_text[i] = np.array(masked_token_ids)
pairs_token_labels[i] = np.array(token_labels)
# For generate captions
caption_words = self.tokenizer.tokenize(caption['text'][ind])
if len(caption_words) > total_length_with_CLS:
caption_words = caption_words[:total_length_with_CLS]
input_caption_words = ["[CLS]"] + caption_words
output_caption_words = caption_words + ["[SEP]"]
# For generate captions
input_caption_ids = self.tokenizer.convert_tokens_to_ids(input_caption_words)
output_caption_ids = self.tokenizer.convert_tokens_to_ids(output_caption_words)
decoder_mask = [1] * len(input_caption_ids)
while len(input_caption_ids) < self.max_words:
input_caption_ids.append(0)
output_caption_ids.append(0)
decoder_mask.append(0)
assert len(input_caption_ids) == self.max_words
assert len(output_caption_ids) == self.max_words
assert len(decoder_mask) == self.max_words
pairs_input_caption_ids[i] = np.array(input_caption_ids)
pairs_output_caption_ids[i] = np.array(output_caption_ids)
pairs_decoder_mask[i] = np.array(decoder_mask)
return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels,\
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, starts, ends
def _get_video(self, idx, s, e):
video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
max_video_length = [0] * len(s)
f_title = "feature_file_2D"
fps_k = "2d"
feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
video_features = np.load(feature_file)
video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
for i in range(len(s)):
start = int(s[i] * self.fps[fps_k])
end = int(e[i] * self.fps[fps_k]) + 1
video_slice = video_features[start:end]
if self.max_frames < video_slice.shape[0]:
video_slice = video_slice[:self.max_frames]
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
print("video_id: {}, start: {}, end: {}".format(feature_file, start, end))
# pass
else:
video[i][:slice_shape[0]] = video_slice
for i, v_length in enumerate(max_video_length):
video_mask[i][:v_length] = [1] * v_length
# Mask Frame Model <-----
video_labels_index = [[] for _ in range(len(s))]
masked_video = video.copy()
for i, video_pair_ in enumerate(masked_video):
for j, _ in enumerate(video_pair_):
if j < max_video_length[i]:
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
masked_video[i][j] = [0.] * video.shape[-1]
video_labels_index[i].append(j)
else:
video_labels_index[i].append(-1)
else:
video_labels_index[i].append(-1)
video_labels_index = np.array(video_labels_index, dtype=np.long)
# -----> Mask Frame Model
return video, video_mask, masked_video, video_labels_index
def __getitem__(self, feature_idx):
video_id, sub_id = self.iter2video_pairs_dict[feature_idx]
idx = self.video_id2idx_dict[video_id]
pairs_text, pairs_mask, pairs_segment, \
pairs_masked_text, pairs_token_labels, pairs_input_caption_ids, \
pairs_decoder_mask, pairs_output_caption_ids, starts, ends = self._get_text(video_id, sub_id)
video, video_mask, masked_video, video_labels_index = self._get_video(idx, starts, ends)
return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids
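The caption branch above is standard teacher forcing: decoder input and target are the same token sequence offset by one position. A tiny worked example (the token strings are made up) of what `_get_text` builds before padding:

```python
# Illustrative only: shows the one-position shift between decoder input and target.
caption_words        = ["add", "oil", "to", "the", "pan"]   # hypothetical tokenizer output
input_caption_words  = ["[CLS]"] + caption_words            # what the decoder reads
output_caption_words = caption_words + ["[SEP]"]            # what the decoder must predict
# At step t the decoder conditions on input_caption_words[:t + 1] and is trained to
# emit output_caption_words[t]; the corresponding ids are then padded up to max_words.
```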


@ -0,0 +1,419 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import re
import random
import io
class Youtube_Transcript_DataLoader(Dataset):
"""
Youtube dataset loader.
Note: uses the transcript as the caption for the masked decoder pretraining task.
"""
def __init__(
self,
csv,
features_path,
caption,
tokenizer,
min_time=10.0,
features_path_3D=None,
feature_framerate=1.0,
feature_framerate_3D=24.0 / 16.0,
max_words=30,
min_words=0,
n_pair=-1, # -1 for test
pad_token="[PAD]",
max_frames=100,
with_long_context=True,
use_mil=False,
only_sim=False, # set automatically from model choice
sampled_use_mil=False,
pretrain_enhance_vmodal=False,
video_dim=1024,
):
"""
Args:
"""
self.csv = pd.read_csv(csv)
self.features_path_2D = features_path
self.features_path_3D = features_path_3D
self.caption = caption
self.min_time = min_time
self.feature_framerate = feature_framerate
self.feature_framerate_3D = feature_framerate_3D
self.max_words = max_words
self.max_frames = max_frames
self.min_words = min_words
self.tokenizer = tokenizer
self.n_pair = n_pair
self.pad_token = pad_token
self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
self.feature_path = {'2d': features_path}
if features_path_3D != '':
self.feature_path['3d'] = features_path_3D
self.with_long_context = with_long_context
self.feature_size = video_dim
self.only_sim = only_sim
self.pretrain_enhance_vmodal = pretrain_enhance_vmodal
self.iter_num = len(self.csv)
self.use_mil = use_mil
self.sampled_use_mil = sampled_use_mil
if self.sampled_use_mil:  # sample within each video; takes priority over use_mil
self.use_mil = True
if self.use_mil:
positive_n_pair = self.n_pair
# Get iterator video ids
video_id_list = [itm for itm in self.csv['video_id'].values]
self.video_id2idx_dict = {video_id: id for id, video_id in enumerate(video_id_list)}
# Get all captions
self.iter2video_pairs_dict = {}
self.iter2video_pairslist_dict = {}
iter_idx_mil_ = 0
for video_id in video_id_list:
caption = self.caption[video_id]
n_caption = len(caption['start'])
sub_list = []
if self.n_pair < 0 or self.n_pair == 1:
for sub_id in range(n_caption):
sub_list.append([sub_id])
else:
sb_ls_ = list(range(n_caption))
sb_st_ = set(sb_ls_)
if self.n_pair > n_caption:
sb_ls_ = sb_ls_ * (self.n_pair // n_caption + 1)
sb_ls_ = sb_ls_[:self.n_pair]
for sub_id in np.arange(0, len(sb_ls_), self.n_pair):
sub_list.append(sb_ls_[sub_id: sub_id + self.n_pair])
else:
sb_ls_ = sb_ls_ + sb_ls_[:(((n_caption+positive_n_pair-1)//positive_n_pair)*positive_n_pair-n_caption)]
for sub_id in np.arange(0, len(sb_ls_), positive_n_pair):
pos_ls = sb_ls_[sub_id: sub_id + positive_n_pair]
sub_list.append(pos_ls)
for sub_e in sub_list:
self.iter2video_pairs_dict[iter_idx_mil_] = (video_id, sub_e)
iter_idx_mil_ += 1
self.iter2video_pairslist_dict[video_id] = sub_list
if self.use_mil and self.sampled_use_mil is False:
self.iter_num = len(self.iter2video_pairs_dict)
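# Worked example of the grouping above (illustrative): with n_pair=3 and a video
# holding 7 clips, sb_ls_ is padded to [0,1,2,3,4,5,6,0,1] and split into the groups
# [0,1,2], [3,4,5], [6,0,1]; each group becomes one (video_id, sub_ids) entry in
# iter2video_pairs_dict.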
def __len__(self):
return self.iter_num
def _mask_tokens(self, words, orig2token_tuple_list=None, chunk_positions=None):
token_labels = []
masked_tokens = words.copy()
if chunk_positions is not None and len(chunk_positions)>0:
token_labels = [-1] * len(masked_tokens)
for chunk_ind in chunk_positions:
start_, len_ = orig2token_tuple_list[chunk_ind]
if random.random() < 0.15:
token_labels[start_:start_+len_] = [self.tokenizer.vocab[tk_] for tk_ in masked_tokens[start_:start_+len_]]
masked_tokens[start_:start_+len_] = ["[MASK]"] * len_
else:
for token_id, token in enumerate(masked_tokens):
if token_id == 0 or token_id == len(masked_tokens) - 1:
token_labels.append(-1)
continue
prob = random.random()
if prob < 0.15:
prob /= 0.15
if prob < 0.8:
masked_tokens[token_id] = "[MASK]"
elif prob < 0.9:
masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
try:
token_labels.append(self.tokenizer.vocab[token])
except KeyError:
token_labels.append(self.tokenizer.vocab["[UNK]"])
else:
token_labels.append(-1)
return masked_tokens, token_labels
def _get_text(self, video_id, n_pair_max, sub_ids=None, only_sim=False, enhance_vmodel=False):
caption = self.caption[video_id]
if self.use_mil:
k = len(sub_ids)
r_ind = sub_ids
else:
n_caption = len(caption['start'])
if n_pair_max == -1:
k = n_caption
r_ind = range(n_caption)
else:
k = n_pair_max
if k <= n_caption:
r_ind = np.random.choice(range(n_caption), k, replace=False)
else:
r_ind_must = np.array(range(n_caption))
r_ind_rand = np.random.choice(range(n_caption), k-n_caption, replace=True)
r_ind = np.concatenate((r_ind_must, r_ind_rand), axis=0)
np.random.shuffle(r_ind)
starts = np.zeros(k)
ends = np.zeros(k)
pairs_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)
pairs_input_caption_ids = np.zeros((k, self.max_words), dtype=np.long)
pairs_output_caption_ids = np.zeros((k, self.max_words), dtype=np.long)
pairs_decoder_mask = np.zeros((k, self.max_words), dtype=np.long)
for i in range(k):
ind = r_ind[i]
words, start_, end_ = self._get_single_transcript(caption, ind, with_long_context=self.with_long_context)
caption_words = words.copy()
starts[i], ends[i] = start_, end_
orig2token_tuple_list, chunk_positions = None, None
# # For entity mask
# orig_sentence, orig_to_token_map = generate_org_sentence_from_tokens(words.copy())
# four_tuples_ = generate_knoledge_candidate(orig_sentence, orig_to_token_map)
# _, orig2token_tuple_list, chunk_positions, _ = four_tuples_
if enhance_vmodel:
words = [] # mask all input text
words = ["[CLS]"] + words
total_length_with_CLS = self.max_words - 1
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + ["[SEP]"]
input_ids = self.tokenizer.convert_tokens_to_ids(words)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
while len(input_ids) < self.max_words:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == self.max_words
assert len(input_mask) == self.max_words
assert len(segment_ids) == self.max_words
pairs_text[i] = np.array(input_ids)
pairs_mask[i] = np.array(input_mask)
pairs_segment[i] = np.array(segment_ids)
if only_sim is False:
# # For entity mask : +1 due to [CLS]
# orig2token_tuple_list = [(s_+1, len_) for s_, len_ in orig2token_tuple_list if s_+len_ <= total_length_with_CLS]
# chunk_positions = [ind for ind in chunk_positions if ind<len(orig2token_tuple_list)]
# For generate captions
if len(caption_words) > total_length_with_CLS:
caption_words = caption_words[:total_length_with_CLS]
input_caption_words = ["[CLS]"] + caption_words
output_caption_words = caption_words + ["[SEP]"]
masked_tokens, token_labels = self._mask_tokens(words, orig2token_tuple_list, chunk_positions)
masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
masked_input_caption_words, input_token_labels = self._mask_tokens(input_caption_words)
input_caption_words = masked_input_caption_words.copy()
while len(masked_token_ids) < self.max_words:
masked_token_ids.append(0)
token_labels.append(-1)
assert len(masked_token_ids) == self.max_words
assert len(token_labels) == self.max_words
# For generate captions
input_caption_ids = self.tokenizer.convert_tokens_to_ids(input_caption_words)
output_caption_ids = self.tokenizer.convert_tokens_to_ids(output_caption_words)
decoder_mask = [1] * len(input_caption_ids)
while len(input_caption_ids) < self.max_words:
input_caption_ids.append(0)
output_caption_ids.append(0)
decoder_mask.append(0)
assert len(input_caption_ids) == self.max_words
assert len(output_caption_ids) == self.max_words
assert len(decoder_mask) == self.max_words
pairs_masked_text[i] = np.array(masked_token_ids)
pairs_token_labels[i] = np.array(token_labels)
pairs_input_caption_ids[i] = np.array(input_caption_ids)
pairs_output_caption_ids[i] = np.array(output_caption_ids)
pairs_decoder_mask[i] = np.array(decoder_mask)
return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, \
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, starts, ends
def _get_single_transcript(self, caption, ind, with_long_context=True):
start, end = ind, ind
words = self.tokenizer.tokenize(str(caption['text'][ind]))
diff = caption['end'][end] - caption['start'][start]
while with_long_context and (len(words) < self.min_words or diff < self.min_time):
if start > 0 and end < len(caption['end']) - 1:
next_words = self.tokenizer.tokenize(str(caption['text'][end + 1]))
prev_words = self.tokenizer.tokenize(str(caption['text'][start - 1]))
d1 = caption['end'][end + 1] - caption['start'][start]
d2 = caption['end'][end] - caption['start'][start - 1]
if (self.min_time > 0 and d2 <= d1) or \
(self.min_time == 0 and len(next_words) <= len(prev_words)):
start -= 1
words = prev_words + words
else:
end += 1
words.extend(next_words)
elif start > 0:
words = self.tokenizer.tokenize(str(caption['text'][start - 1])) + words
start -= 1
elif end < len(caption['end']) - 1:
words.extend(self.tokenizer.tokenize(str(caption['text'][end + 1])))
end += 1
else:
break
diff = caption['end'][end] - caption['start'][start]
return words, caption['start'][start], caption['end'][end]
def _expand_video_slice(self, s, e, si, ei, fps, video_features, fps_k):
start = int(s[si] * fps)
end = int(e[ei] * fps) + 1
if start > end:
start, end = end, start
video_slice = video_features[start:end]
expand_left = True
while len(video_slice) < 1:
if si==0 and ei==len(s)-1:
break
if expand_left:
expand_left = False
si = si-1 if si>0 else si
else:
expand_left = True
ei = ei+1 if ei<len(e)-1 else ei
start = int(s[si] * fps)
end = int(e[ei] * fps) + 1
if start > end:
start, end = end, start
video_slice = video_features[start:end]
# # Note: to align the 2d and 3d features, since their fps differ.
# if fps_k == "2d":
# indices = sorted([idx for idx in range(len(video_slice)) if idx % 2 == 1]
# + [idx for idx in range(len(video_slice))])
# video_slice = np.take(video_slice, indices, axis=0)
if self.max_frames < video_slice.shape[0]:
video_slice = video_slice[:self.max_frames]
return video_slice, start, end
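# In short: if the [start, end) window above maps to zero feature frames, the slice
# is widened to neighbouring clip boundaries, alternating left (si) and right (ei),
# until it is non-empty or already spans the whole clip list.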
def _get_video(self, idx, s, e, only_sim=False):
video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
max_video_length = [0] * len(s)
f_title = "feature_file_2D"
fps_k = "2d"
video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
try:
video_features = np.load(feature_file)
for i in range(len(s)):
# start = int(s[i] * self.fps[fps_k])
# end = int(e[i] * self.fps[fps_k]) + 1
# video_slice = video_features[start:end]
#
# if self.max_frames < video_slice.shape[0]:
# video_slice = video_slice[:self.max_frames]
if len(video_features) < 1:
raise ValueError("{} is empty.".format(feature_file))
video_slice, start, end = self._expand_video_slice(s, e, i, i,
self.fps[fps_k], video_features, fps_k)
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
pass
# video_slice = video_features
# if self.max_frames < video_slice.shape[0]:
# video_slice = video_slice[:self.max_frames]
else:
video[i][:slice_shape[0]] = video_slice
except Exception as excep:
print("feature_file: {} error: {}".format(feature_file, excep))
for i, v_length in enumerate(max_video_length):
video_mask[i][:v_length] = [1] * v_length
# Mask Frame Model <-----
video_labels_index = [[] for _ in range(len(s))]
masked_video = video.copy()
if only_sim is False:
for i, video_pair_ in enumerate(masked_video):
for j, _ in enumerate(video_pair_):
if j < max_video_length[i]:
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
masked_video[i][j] = [0.] * video.shape[-1]
video_labels_index[i].append(j)
else:
video_labels_index[i].append(-1)
else:
video_labels_index[i].append(-1)
video_labels_index = np.array(video_labels_index, dtype=np.long)
# -----> Mask Frame Model
return video, video_mask, masked_video, video_labels_index
def second_to_stamp(self, in_seconds):
m, s = divmod(in_seconds, 60)
h, m2 = divmod(m, 60)
return "%02d:%02d:%02d" % (h, m2, s)
def __getitem__(self, feature_idx):
if self.sampled_use_mil:  # sample within each video; takes priority over use_mil
idx = feature_idx
video_id = self.csv['video_id'].values[idx]
sub_list = self.iter2video_pairslist_dict[video_id]
ranint = np.random.randint(0, len(sub_list))
sub_ids = sub_list[ranint]
elif self.use_mil:
video_id, sub_ids = self.iter2video_pairs_dict[feature_idx]
idx = self.video_id2idx_dict[video_id]
else:
idx = feature_idx
video_id = self.csv['video_id'].values[idx]
sub_ids = None
enhance_vmodel = False
if self.only_sim is False and self.pretrain_enhance_vmodal:
prob = random.random()
if prob < 0.15:  # mask all input text with probability 0.15
enhance_vmodel = True
pairs_text, pairs_mask, pairs_segment, \
pairs_masked_text, pairs_token_labels, pairs_input_caption_ids, \
pairs_decoder_mask, pairs_output_caption_ids, \
starts, ends = self._get_text(video_id, self.n_pair, sub_ids, only_sim=self.only_sim, enhance_vmodel=enhance_vmodel)
video, video_mask, masked_video, video_labels_index = self._get_video(idx, starts, ends, only_sim=self.only_sim)
return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids
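The token-masking recipe repeated in each loader above (pick 15% of non-special positions; of those, 80% become `[MASK]`, 10% a random vocabulary token, 10% stay unchanged, and the original id is kept as the label) can be read as a small standalone helper. The following is a hedged restatement for a plain token list, not code shipped in this commit:

```python
import random

def bert_style_mask(tokens, vocab, mask_prob=0.15):
    """Return (masked_tokens, labels); labels hold the original vocab id at
    positions selected for prediction and -1 everywhere else."""
    masked, labels = list(tokens), []
    for i, tok in enumerate(tokens):
        # the first ([CLS]) and last ([SEP]) positions are never masked
        if i == 0 or i == len(tokens) - 1:
            labels.append(-1)
            continue
        prob = random.random()
        if prob < mask_prob:
            prob /= mask_prob
            if prob < 0.8:                      # 80%: replace with [MASK]
                masked[i] = "[MASK]"
            elif prob < 0.9:                    # 10%: replace with a random vocab token
                masked[i] = random.choice(list(vocab))
            # remaining 10%: keep the original token unchanged
            labels.append(vocab.get(tok, vocab["[UNK]"]))
        else:
            labels.append(-1)                   # position ignored by the loss
    return masked, labels
```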