This commit is contained in:
Parent b0ea315de1
Commit f08a73055d

@@ -0,0 +1 @@
.idea

LICENSE (21 lines deleted)
@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2020 ArrowLuo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,91 @@

# Preliminary
Execute the scripts below in the main folder first.
```
cd pytorch_pretrained_bert/bert-base-uncased/
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt
mv bert-base-uncased-vocab.txt vocab.txt
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz
tar -xvf bert-base-uncased.tar.gz
rm bert-base-uncased.tar.gz
cd ../../
```

# Finetune on YoucookII
## Retrieval
```
INIT_MODEL=<from second phase>
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_retrieval_task.py \
--do_train --num_thread_reader=16 \
--epochs=5 --batch_size=32 \
--n_pair=1 --n_display=100 \
--youcook_train_csv data/youcookii_train.csv \
--youcook_val_csv data/youcookii_val.csv \
--youcook_caption_path data/youcookii_caption.pickle \
--youcook_features_path_2D data/youcookii_videos_features \
--output_dir ckpt_youcook_retrieval --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 48 --max_frames 48 \
--batch_size_val 200 --visual_num_hidden_layers 6 \
--datatype youcook --init_model ${INIT_MODEL}
```

## Caption
```
INIT_MODEL=<from second phase>
python -m torch.distributed.launch --nproc_per_node=4 train_transcript_distributed.py \
--do_train --num_thread_reader=4 \
--epochs=5 --batch_size=16 \
--n_pair=-1 --n_display=100 \
--youcook_train_csv data/youcookii_train.csv \
--youcook_val_csv data/youcookii_val.csv \
--youcook_caption_path data/youcookii_caption_transcript.pickle \
--youcook_features_path_2D data/youcookii_videos_features \
--output_dir ckpt_youcook_caption --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 128 --max_frames 96 \
--batch_size_val 64 --visual_num_hidden_layers 6 \
--decoder_num_hidden_layers 3 \
--init_model ${INIT_MODEL}
```

# Pretrain on HowTo100M
## Phase I
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
python -m torch.distributed.launch --nproc_per_node=8 ${MODEL_PATH}/train_transcript_distributed.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=1920 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/pre_s3d_L48_V6_D3_Phase1 \
--features_path_2D ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--caption_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --gradient_accumulation_steps 16 \
--sampled_use_mil --load_checkpoint
```

## Phase II
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
INIT_MODEL=<from first phase>
python -m torch.distributed.launch --nproc_per_node=8 ${MODEL_PATH}/train_transcript_distributed.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=960 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/pre_s3d_L48_V6_D3_Phase2 \
--features_path_2D ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--caption_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --decoder_num_hidden_layers 3 \
--gradient_accumulation_steps 60 \
--cross_model --pretrain_with_joint_sim --sampled_use_mil \
--pretrain_enhance_vmodal --pretrain_without_decoder \
--load_checkpoint --init_model ${INIT_MODEL}
```
@@ -0,0 +1,9 @@
# pip install prefetch_generator

from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):
    def __iter__(self):
        # Transform the generator into a background-thread generator.
        return BackgroundGenerator(super().__iter__(), max_prefetch=1)
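
DataLoaderX is a drop-in replacement for `torch.utils.data.DataLoader`: BackgroundGenerator iterates the wrapped loader on a separate thread, so the next batch is prepared while the current one is consumed. A minimal usage sketch (the toy dataset and batch size below are illustrative, not part of this commit):
```
import torch
from torch.utils.data import TensorDataset

# Toy dataset: 100 random feature vectors with binary labels.
dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))

# Same arguments as a regular DataLoader; one batch (max_prefetch=1)
# is prepared in the background while the loop body runs.
loader = DataLoaderX(dataset, batch_size=16, shuffle=True)

for features, labels in loader:
    pass  # training step goes here; the next batch is already loading
```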
@@ -0,0 +1,28 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import numpy as np

def compute_metrics(x):
    # x is a (num_queries x num_candidates) similarity matrix whose diagonal
    # holds each query's score against its ground-truth match. Sorting the
    # negated scores and locating the diagonal value yields the 0-based rank
    # of the correct candidate for every query.
    sx = np.sort(-x, axis=1)
    d = np.diag(-x)
    d = d[:, np.newaxis]
    ind = sx - d
    ind = np.where(ind == 0)
    ind = ind[1]
    metrics = {}
    metrics['R1'] = float(np.sum(ind == 0)) / len(ind)
    metrics['R5'] = float(np.sum(ind < 5)) / len(ind)
    metrics['R10'] = float(np.sum(ind < 10)) / len(ind)
    metrics['MR'] = np.median(ind) + 1
    return metrics


def print_computed_metrics(metrics):
    r1 = metrics['R1']
    r5 = metrics['R5']
    r10 = metrics['R10']
    mr = metrics['MR']
    print('R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.format(r1, r5, r10, mr))
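
As a quick sanity check, here is an illustrative run on a hand-picked 3x3 similarity matrix (values chosen for the example, not from the repo): the true match ranks first for two queries and second for the third, so R@1 is 2/3 while R@5 and R@10 are 1.
```
import numpy as np

# Rows are queries, columns are candidates; the diagonal is the true pair.
sim = np.array([[0.9, 0.1, 0.2],   # true match ranked 1st
                [0.3, 0.8, 0.1],   # true match ranked 1st
                [0.7, 0.2, 0.5]])  # 0.5 < 0.7, so true match ranked 2nd

m = compute_metrics(sim)
print_computed_metrics(m)
# Expected: R@1: 0.6667 - R@5: 1.0000 - R@10: 1.0000 - Median R: 1.0
```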
@@ -0,0 +1,117 @@
""" Manage beam search info structure.

Heavily borrowed from OpenNMT-py.
For code in OpenNMT-py, please check the following link:
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
"""

import torch
import numpy as np

class Constants():
    def __init__(self):
        self.PAD = 0
        self.UNK = 1
        self.BOS = 2
        self.EOS = 3
        self.PAD_WORD = '[MASK]'
        self.UNK_WORD = '[UNK]'
        self.BOS_WORD = '[CLS]'
        self.EOS_WORD = '[SEP]'

    @classmethod
    def from_tokenizer(cls, tokenizer):
        instance = cls()
        instance.PAD = tokenizer.vocab[instance.PAD_WORD]
        instance.UNK = tokenizer.vocab[instance.UNK_WORD]
        instance.BOS = tokenizer.vocab[instance.BOS_WORD]
        instance.EOS = tokenizer.vocab[instance.EOS_WORD]
        return instance

class Beam():
    ''' Beam search '''

    def __init__(self, size, device=False, tokenizer=None):
        if tokenizer is None:
            self.constants = Constants()
        else:
            self.constants = Constants.from_tokenizer(tokenizer)

        self.size = size
        self._done = False
        # The score for each hypothesis on the beam.
        self.scores = torch.zeros((size,), dtype=torch.float, device=device)
        self.all_scores = []

        # The backpointers at each time-step.
        self.prev_ks = []

        # The outputs at each time-step.
        self.next_ys = [torch.full((size,), self.constants.BOS, dtype=torch.long, device=device)]

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        return self._done

    def advance(self, word_prob, word_length=None):
        "Update beam status and check if finished or not."
        num_words = word_prob.size(1)
        # Sum the previous scores.
        if len(self.prev_ks) > 0:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            beam_lk = word_prob[0]
        flat_beam_lk = beam_lk.view(-1)
        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)  # 1st sort
        self.all_scores.append(self.scores)
        self.scores = best_scores
        # best_scores_id is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from.
        # Integer division: a plain `/` yields floats on recent PyTorch.
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0].item() == self.constants.EOS:
            self._done = True

        return self._done

    def sort_scores(self):
        "Sort the scores."
        return torch.sort(self.scores, 0, True)

    def get_the_best_score_and_idx(self):
        "Get the score of the best in the beam."
        scores, ids = self.sort_scores()
        # sort_scores sorts in descending order, so index 0 is the best.
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        "Get the decoded sequence for the current timestep."

        if len(self.next_ys) == 1:
            dec_seq = self.next_ys[0].unsqueeze(1)
        else:
            _, keys = self.sort_scores()
            hyps = [self.get_hypothesis(k) for k in keys]
            hyps = [[self.constants.BOS] + h for h in hyps]
            dec_seq = torch.LongTensor(hyps)

        return dec_seq

    def get_hypothesis(self, k):
        """ Walk back to construct the full hypothesis. """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j+1][k])
            k = self.prev_ks[j][k]

        return list(map(lambda x: x.item(), hyp[::-1]))
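
For orientation, the expected driver loop looks roughly like the hedged sketch below; the beam width, vocabulary size, and random log-probabilities are stand-ins for what a real decoder would feed into `advance` (computed from `get_current_state`) at each step:
```
import torch

beam = Beam(size=3, device='cpu')  # default Constants: BOS=2, EOS=3

vocab_size = 10
for step in range(5):
    # Fake per-beam log-probabilities over the vocabulary; a real decoder
    # would compute these from beam.get_current_state().
    word_prob = torch.log_softmax(torch.randn(3, vocab_size), dim=-1)
    if beam.advance(word_prob):
        break  # top hypothesis ended with EOS

_, ks = beam.sort_scores()
best_hyp = beam.get_hypothesis(ks[0])  # token ids, BOS excluded
print(best_hyp)
```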
@@ -0,0 +1,711 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}

CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
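
To make the docstring's "slightly different" concrete, here is a small illustrative comparison of the exact erf form above against the GPT-style tanh approximation (a sketch for reference, not part of the commit):
```
import math
import torch

def gelu_tanh(x):
    # The GPT-style tanh approximation quoted in the docstring.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3, 3, 61)
print((gelu(x) - gelu_tanh(x)).abs().max())  # small: the curves agree to ~3 decimals
```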


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stdev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = BertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
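
The three construction paths (integer vocabulary size, JSON file, plain dict) round-trip through the serializers; a brief illustrative sketch (parameter values are made up):
```
# Build from an integer vocabulary size; everything else takes defaults.
config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=6)

# Round-trip through the dict/JSON serializers.
clone = BertConfig.from_dict(config.to_dict())
assert clone.num_hidden_layers == 6
print(config.to_json_string())
```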

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
    class BertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the BertModel forward() function).
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
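
`transpose_for_scores` splits the last (hidden) dimension into heads and moves the head axis forward so each head attends independently; a standalone shape walk-through with dummy sizes (no model weights involved):
```
import torch

batch, seq_len, heads, head_size = 2, 5, 12, 64
x = torch.randn(batch, seq_len, heads * head_size)      # (2, 5, 768)

# Same reshape as transpose_for_scores:
x = x.view(batch, seq_len, heads, head_size)            # (2, 5, 12, 64)
x = x.permute(0, 2, 1, 3)                               # (2, 12, 5, 64)

scores = torch.matmul(x, x.transpose(-1, -2))           # (2, 12, 5, 5): per-head attention
print(scores.shape)
```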


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers


class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super(BertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

class PreTrainedBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedBertModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        model = self.module if hasattr(self, 'module') else self
        base_model = getattr(model, "bert", model)  # get the base model if needed

        old_embeddings = base_model.embeddings.word_embeddings
        # get_resized_embeddings <-----
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        device = old_embeddings.weight.device
        new_embeddings.to(device)

        # initialize all new embeddings (in particular added tokens)
        self.init_bert_weights(new_embeddings)

        # Copy word embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
        # ----->

        # Update base model and current model config
        base_model.embeddings.word_embeddings = new_embeddings
        base_model.config.vocab_size = new_num_tokens
        model.config.vocab_size = new_num_tokens
        model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)
    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        '''
        abstracted from `from_pretrained`
        '''
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights file doesn't exist: {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def init_preweight(cls, model, state_dict):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
        if len(error_msgs) > 0:
            logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_classes for BertForSequenceClassification)
        """
        config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedBertModel.init_preweight(model, state_dict)

        return model
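
A brief, hedged sketch of the loading path (it uses the `BertModel` subclass defined just below; the extra token count is hypothetical, and the download assumes network access or a warm cache):
```
# Download (or read from cache) the archive, parse bert_config.json, and load
# pytorch_model.bin into a freshly constructed model.
model = BertModel.from_pretrained('bert-base-uncased')

# If extra tokens were added to the tokenizer, grow the word-embedding matrix
# (30522 is bert-base-uncased's vocabulary; +2 hypothetical new tokens).
model.resize_token_embeddings(new_num_tokens=30522 + 2)
```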
class BertModel(PreTrainedBertModel):
    """BERT model ("Bidirectional Encoder Representations from Transformers").

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `type`: a str, indicates which masking will be used in the attention, choice from [`bi`, `seq`, `gen`]
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first token of the
            input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model('bi', input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, type, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
                attention_sentenceA_mask=None,
                attention_sentenceB_mask=None):

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
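
The additive mask trick in `forward` is worth seeing with numbers; a standalone demo with a toy mask (illustrative, not tied to the model; in the model the mask is additionally unsqueezed to [batch, 1, 1, seq] so it broadcasts over heads and query positions):
```
import torch

attention_mask = torch.tensor([[1, 1, 0]]).float()  # last position is padding
additive_mask = (1.0 - attention_mask) * -10000.0   # [[0., 0., -10000.]]

scores = torch.tensor([[0.5, 0.2, 0.9]])            # raw attention scores
probs = torch.softmax(scores + additive_mask, dim=-1)
print(probs)  # tensor([[0.5744, 0.4256, 0.0000]]) -- padding is effectively removed
```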

@@ -0,0 +1,12 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 1024,
  "num_attention_heads": 12,
  "num_hidden_layers": 2,
  "vocab_size": 768
}
@@ -0,0 +1,624 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'cross-bert-base': "[TODO]",
    'cross-bert-large': "[TODO]",
}

CONFIG_NAME = 'cross_bert_config.json'
WEIGHTS_NAME = 'cross_pytorch_model.bin'

def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

class CrossBertConfig(object):
    """Configuration class to store the configuration of a `CrossBertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs CrossBertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossBertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `CrossBertModel`.
            initializer_range: The stdev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `CrossBertConfig` from a Python dictionary of parameters."""
        config = CrossBertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `CrossBertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as CrossBertLayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
    class CrossBertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(CrossBertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias

class CrossBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(CrossBertEmbeddings, self).__init__()

        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, concat_embeddings, concat_type=None):

        batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
        if concat_type is None:
            # Default to segment id 0 everywhere. (The original line passed the
            # still-None `concat_type` as a size argument to torch.zeros, which
            # would raise; a LongTensor of shape (batch, seq) is the intent.)
            concat_type = torch.zeros(batch_size, seq_length, dtype=torch.long, device=concat_embeddings.device)

        position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device)
        position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1)

        token_type_embeddings = self.token_type_embeddings(concat_type)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = concat_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class CrossBertSelfAttention(nn.Module):
    def __init__(self, config):
        super(CrossBertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the CrossBertModel forward() function).
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer


class CrossBertSelfOutput(nn.Module):
    def __init__(self, config):
        super(CrossBertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class CrossBertAttention(nn.Module):
    def __init__(self, config):
        super(CrossBertAttention, self).__init__()
        self.self = CrossBertSelfAttention(config)
        self.output = CrossBertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class CrossBertIntermediate(nn.Module):
    def __init__(self, config):
        super(CrossBertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class CrossBertOutput(nn.Module):
    def __init__(self, config):
        super(CrossBertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class CrossBertLayer(nn.Module):
    def __init__(self, config):
        super(CrossBertLayer, self).__init__()
        self.attention = CrossBertAttention(config)
        self.intermediate = CrossBertIntermediate(config)
        self.output = CrossBertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class CrossBertEncoder(nn.Module):
    def __init__(self, config):
        super(CrossBertEncoder, self).__init__()
        layer = CrossBertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers


class CrossBertPooler(nn.Module):
    def __init__(self, config):
        super(CrossBertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class CrossBertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(CrossBertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class CrossBertLMPredictionHead(nn.Module):
    def __init__(self, config, cross_bert_model_embedding_weights):
        super(CrossBertLMPredictionHead, self).__init__()
        self.transform = CrossBertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(cross_bert_model_embedding_weights.size(1),
                                 cross_bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = cross_bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(cross_bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
||||
hidden_states = self.decoder(hidden_states) + self.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CrossBertOnlyMLMHead(nn.Module):
|
||||
def __init__(self, config, cross_bert_model_embedding_weights):
|
||||
super(CrossBertOnlyMLMHead, self).__init__()
|
||||
self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
|
||||
|
||||
def forward(self, sequence_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class CrossBertOnlyNSPHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(CrossBertOnlyNSPHead, self).__init__()
|
||||
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
||||
|
||||
def forward(self, pooled_output):
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
return seq_relationship_score
|
||||
|
||||
|
||||
class CrossBertPreTrainingHeads(nn.Module):
|
||||
def __init__(self, config, cross_bert_model_embedding_weights):
|
||||
super(CrossBertPreTrainingHeads, self).__init__()
|
||||
self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
|
||||
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
||||
|
||||
def forward(self, sequence_output, pooled_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
class PreTrainedCrossBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedCrossBertModel, self).__init__()
        if not isinstance(config, CrossBertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `CrossBertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_cross_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, CrossBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        '''
        Abstracted from `from_pretrained`.
        '''
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = CrossBertConfig.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights do not exist at {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def init_preweight(cls, model, state_dict):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'cross_bert') else 'cross_bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
        if len(error_msgs) > 0:
            logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        """
        Instantiate a PreTrainedCrossBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `cross_bert-base`
                    . `cross_bert-large`
                - a path or url to a pretrained model archive containing:
                    . `cross-bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a CrossBertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific CrossBert class
                (ex: num_classes for CrossBertForSequenceClassification)
        """
        config, state_dict = PreTrainedCrossBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedCrossBertModel.init_preweight(model, state_dict)

        return model

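For orientation, a minimal, hypothetical usage sketch of the loader above; the directory name and its contents are assumptions (this commit defines the loading logic, not a published archive):

```python
# Usage sketch: assumes a local directory "cross_bert-base" that holds the
# files named by this module's CONFIG_NAME and WEIGHTS_NAME constants.
# CrossBertModel is the concrete subclass defined just below.
model = CrossBertModel.from_pretrained("cross_bert-base", type_vocab_size=2)
model.eval()
```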
class CrossBertModel(PreTrainedCrossBertModel):
    def __init__(self, config):
        super(CrossBertModel, self).__init__(config)
        self.embeddings = CrossBertEmbeddings(config)
        self.encoder = CrossBertEncoder(config)
        self.pooler = CrossBertPooler(config)
        self.apply(self.init_cross_bert_weights)

    def forward(self, type, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True,
                attention_sentenceA_mask=None,
                attention_sentenceB_mask=None):

        if attention_mask is None:
            attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))
        if concat_type is None:
            concat_type = torch.zeros_like(attention_mask)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(concat_input, concat_type)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
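A quick sanity check of the additive mask construction in `forward()` above, with illustrative sizes:

```python
import torch

# 1 = attend, 0 = masked; one batch item, four positions, last one padded.
attention_mask = torch.tensor([[1, 1, 1, 0]])            # [batch, seq]
extended = attention_mask.unsqueeze(1).unsqueeze(2)      # [batch, 1, 1, seq]
extended = (1.0 - extended.float()) * -10000.0
print(extended)  # approx tensor([[[[0., 0., 0., -10000.]]]])
# Added to the raw scores before softmax, -10000 drives the masked
# position's attention probability to effectively zero.
```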
@@ -0,0 +1,14 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "num_decoder_layers": 1,
  "max_target_embeddings": 512
}
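This JSON's keys match the `BertConfig` fields of the decoder module below, so it can be round-tripped through that class; the file name here is an assumption based on the `CONFIG_NAME` constant defined below:

```python
# Sketch: loading the config above via the decoder module's BertConfig.
config = BertConfig.from_json_file("decoder_bert_config.json")
print(config.hidden_size, config.num_decoder_layers)  # 768 1
```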
@@ -0,0 +1,713 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'decoder-bert-base': "[TODO]",
    'decoder-bert-large': "[TODO]",
}

CONFIG_NAME = 'decoder_bert_config.json'
WEIGHTS_NAME = 'decoder_pytorch_model.bin'


def gelu(x):
    """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

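The `gelu` docstring above quotes OpenAI GPT's tanh approximation; a quick numerical comparison confirms the two track each other closely:

```python
import math
import torch

x = torch.linspace(-3, 3, 7)
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print((exact - approx).abs().max())  # small, on the order of 1e-3 or below
```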
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 max_target_embeddings=128,
                 num_decoder_layers=1):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            max_target_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            num_decoder_layers: Number of hidden layers in the Transformer decoder.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.max_target_embeddings = max_target_embeddings
            self.num_decoder_layers = num_decoder_layers
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = BertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as DecoderBertLayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
    class DecoderBertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(DecoderBertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias

class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class BertLMPredictionHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config, bert_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores

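The comment in `BertLMPredictionHead` describes weight tying: the output projection reuses the input embedding matrix, so only the per-token bias adds parameters. A small standalone sketch of the effect (sizes illustrative):

```python
import torch
from torch import nn

vocab, hidden = 100, 16
embedding = nn.Embedding(vocab, hidden)
decoder = nn.Linear(hidden, vocab, bias=False)
decoder.weight = embedding.weight  # tied: the same Parameter object

# Both modules now share one storage; updating one updates the other.
assert decoder.weight.data_ptr() == embedding.weight.data_ptr()
```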
class PreTrainedDecoderBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedDecoderBertModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, DecoderBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        model = self.module if hasattr(self, 'module') else self
        base_model = getattr(model, "bert", model)  # get the base model if needed

        old_embeddings = base_model.embeddings.word_embeddings
        # get_resized_embeddings <-----
        if new_num_tokens is None:
            return old_embeddings

        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_embeddings

        # Build new embeddings
        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        device = old_embeddings.weight.device
        new_embeddings.to(device)

        # initialize all new embeddings (in particular added tokens)
        self.init_bert_weights(new_embeddings)

        # Copy word embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
        # ----->

        # Update base model and current model config
        base_model.embeddings.word_embeddings = new_embeddings
        base_model.config.vocab_size = new_num_tokens
        model.config.vocab_size = new_num_tokens
        model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        '''
        Abstracted from `from_pretrained`.
        '''
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights do not exist at {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def init_preweight(cls, model, state_dict):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
        if len(error_msgs) > 0:
            logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        """
        Instantiate a PreTrainedDecoderBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_classes for BertForSequenceClassification)
        """
        config, state_dict = PreTrainedDecoderBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedDecoderBertModel.init_preweight(model, state_dict)

        return model

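A hedged usage sketch for the loader and the embedding resize above; `MyDecoderHead` stands in for a concrete subclass of `PreTrainedDecoderBertModel`, and the local path is hypothetical (it would need to hold the files named by `CONFIG_NAME` and `WEIGHTS_NAME`):

```python
# Hypothetical subclass and checkpoint directory, for illustration only.
model = MyDecoderHead.from_pretrained("./decoder-bert-base")
# Grow the vocabulary (e.g., after adding special tokens): existing rows are
# copied, new rows re-initialized, and the MLM head rebuilt. This assumes the
# subclass exposes a `bert` sub-module with word embeddings, as the method expects.
model.resize_token_embeddings(new_num_tokens=30524)
```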
# import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):

        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn

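A shape-level sketch of the class above; note it takes a boolean mask and fills masked scores with `-inf`, unlike the additive `-10000.0` masks used elsewhere in these modules. Sizes are illustrative:

```python
import numpy as np
import torch

d_k = 8
attn = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
attn.eval()  # disable dropout for a deterministic check
q = torch.randn(2, 4, d_k)   # [batch, len_q, d_k]
k = torch.randn(2, 6, d_k)   # [batch, len_k, d_k]
v = torch.randn(2, 6, d_k)   # [batch, len_k, d_v]
mask = torch.zeros(2, 4, 6, dtype=torch.bool)  # True = masked out
out, weights = attn(q, k, v, mask=mask)
print(out.shape, weights.shape)  # torch.Size([2, 4, 8]) torch.Size([2, 4, 6])
```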
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, attention_mask):
        mixed_query_layer = self.query(q)
        mixed_key_layer = self.key(k)
        mixed_value_layer = self.value(v)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_scores

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(gelu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output

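The "position-wise" `Conv1d` layers above use a kernel size of 1, which makes them the same map as a per-position `nn.Linear`; only the channel/sequence axes are swapped. A quick equivalence check:

```python
import torch
from torch import nn

conv = nn.Conv1d(4, 8, 1)
lin = nn.Linear(4, 8)
lin.weight.data.copy_(conv.weight.data.squeeze(-1))  # [8, 4, 1] -> [8, 4]
lin.bias.data.copy_(conv.bias.data)

x = torch.randn(2, 5, 4)                          # [batch, seq, d_in]
y_conv = conv(x.transpose(1, 2)).transpose(1, 2)  # Conv1d wants [batch, d_in, seq]
y_lin = lin(x)
print(torch.allclose(y_conv, y_lin, atol=1e-6))   # True
```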
class DecoderAttention(nn.Module):
    def __init__(self, config):
        super(DecoderAttention, self).__init__()
        self.att = MultiHeadAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, q, k, v, attention_mask):
        att_output, attention_probs = self.att(q, k, v, attention_mask)
        attention_output = self.output(att_output, q)
        return attention_output, attention_probs


class DecoderLayer(nn.Module):
    def __init__(self, config):
        super(DecoderLayer, self).__init__()
        self.slf_attn = DecoderAttention(config)
        self.enc_attn = DecoderAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
        dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
        intermediate_output = self.intermediate(dec_output)
        dec_output = self.output(intermediate_output, dec_output)
        return dec_output, dec_att_scores

class BertDecoderEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.
    """
    def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
        super(BertDecoderEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
        self.word_embeddings.weight = bert_word_embeddings_weight
        self.position_embeddings.weight = bert_position_embeddings_weight

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class BertDecoder(nn.Module):
    def __init__(self, config):
        super(BertDecoder, self).__init__()
        layer = DecoderLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])

    def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
        dec_att_scores = None
        all_encoder_layers = []
        all_dec_att_probs = []
        for layer_module in self.layer:
            hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
                all_dec_att_probs.append(dec_att_scores)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
            all_dec_att_probs.append(dec_att_scores)
        return all_encoder_layers, all_dec_att_probs


class BertDecoderClassifier(nn.Module):
    def __init__(self, config, embedding_weights):
        super(BertDecoderClassifier, self).__init__()
        self.cls = BertOnlyMLMHead(config, embedding_weights)

    def forward(self, hidden_states):
        cls_scores = self.cls(hidden_states)
        return cls_scores

class DecoderBertModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_decoder_layers* layers, each
    a :class:`DecoderLayer` (self-attention, encoder attention, feed-forward).

    Args:
        config: parsed model configuration
        bert_word_embeddings_weight: word embedding table shared with the encoder
        bert_position_embeddings_weight: position embedding table shared with the encoder
    """

    def init_decoder_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, DecoderBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
        super(DecoderBertModel, self).__init__()
        self.config = config
        self.max_target_length = config.max_target_embeddings
        self.embeddings = BertDecoderEmbeddings(config, bert_word_embeddings_weight, bert_position_embeddings_weight)
        self.decoder = BertDecoder(config)
        self.classifier = BertDecoderClassifier(config, bert_word_embeddings_weight)
        self.apply(self.init_decoder_bert_weights)

    def forward(self, input_ids, encoder_input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
        """
        Args:
            input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention

        Returns:
            the classifier scores over the vocabulary for each target position,
            of shape `(batch, tgt_len, vocab)`
        """
        embedding_output = self.embeddings(input_ids)

        extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2)  # b x 1 x 1 x ls
        extended_encoder_mask = extended_encoder_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0

        extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
        extended_answer_mask = extended_answer_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility

        sz_b, len_s, _ = embedding_output.size()
        subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
        self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1)  # b x 1 x ls x ls
        slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=next(self.parameters()).dtype)
        self_attn_mask = slf_attn_mask * -10000.0

        decoded_layers, dec_att_scores = self.decoder(embedding_output,
                                                      encoder_outs,
                                                      self_attn_mask,
                                                      extended_encoder_mask,
                                                      )
        sequence_output = decoded_layers[-1]
        cls_scores = self.classifier(sequence_output)

        return cls_scores
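The causal part of the decoder mask built in `forward()` above comes from `torch.triu`; for a length-4 target it looks like this:

```python
import torch

len_s = 4
subsequent_mask = torch.triu(torch.ones(len_s, len_s), diagonal=1)
print(subsequent_mask)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])
# Row i marks positions j > i with 1; merged with the padding mask and
# scaled by -10000.0, each step can only attend to itself and earlier steps.
```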
@@ -0,0 +1,239 @@
"""
|
||||
Utilities for working with the local dataset cache.
|
||||
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
||||
Copyright by the AllenNLP authors.
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union, IO, Callable, Set
|
||||
from hashlib import sha256
|
||||
from functools import wraps
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
Path.home() / '.pytorch_pretrained_bert'))
|
||||
|
||||
|
||||
def url_to_filename(url: str, etag: str = None) -> str:
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise FileNotFoundError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise FileNotFoundError("file {} not found".format(meta_path))

    with open(meta_path) as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag

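A short sketch of the cache-naming round trip these two functions implement; the URL and ETag are purely illustrative:

```python
# url_to_filename produces "<sha256 of url>.<sha256 of etag>"; the companion
# "<filename>.json" metadata file written by get_from_cache() (below) is what
# filename_to_url() reads back.
name = url_to_filename("https://example.com/model.tar.gz", etag='"abc123"')
print(name)
```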
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

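`cached_path` is what the model loaders above call on either a model name or a local checkpoint path; a small illustrative sketch (paths and URLs are hypothetical):

```python
# Local files that exist are returned as-is (FileNotFoundError otherwise).
path = cached_path("/tmp/some_local_file")
# URLs are downloaded once into PYTORCH_PRETRAINED_BERT_CACHE (or cache_dir)
# and the cached file path is returned on subsequent calls:
# path = cached_path("https://example.com/weights.bin")
```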
def split_s3_path(url: str) -> Tuple[str, str]:
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func: Callable):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url: str, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise FileNotFoundError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url: str) -> Optional[str]:
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url: str, temp_file: IO) -> None:
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)

def http_get(url: str, temp_file: IO) -> None:
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path

def read_set_from_file(filename: str) -> Set[str]:
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path: str, dot=True, lower: bool = True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
@@ -0,0 +1,634 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from pytorch_pretrained_bert.bert_module import PreTrainedBertModel, BertModel, BertLayerNorm, BertOnlyMLMHead
from pytorch_pretrained_bert.visual_module import PreTrainedVisualBertModel, VisualBertModel, VisualBertLayerNorm, VisualBertOnlyMLMHead
from pytorch_pretrained_bert.cross_module import PreTrainedCrossBertModel, CrossBertModel, CrossBertLayerNorm
from pytorch_pretrained_bert.decoder_module import PreTrainedDecoderBertModel, DecoderBertModel, DecoderBertLayerNorm

logger = logging.getLogger(__name__)

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")

    class LayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(LayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
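As a quick sanity check of the fallback module (a minimal sketch, assuming the pure-PyTorch `LayerNorm` above is in scope; apex's fused version behaves equivalently), the normalized activations should have roughly zero mean and unit variance along the last dimension before the affine transform:

```
import torch

ln = LayerNorm(hidden_size=8)            # fallback module defined above
x = torch.randn(2, 4, 8) * 3.0 + 1.5     # arbitrary scale/shift
y = ln(x)
# With weight=1 and bias=0 at init, y is just the normalized input.
print(y.mean(-1))                        # ~0 for every position
print(y.var(-1, unbiased=False))         # ~1 for every position
```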


class PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs):
        super(PreTrainedModel, self).__init__()
        self.bert_config = bert_config
        self.visual_config = visual_config
        self.cross_config = cross_config
        self.decoder_config = decoder_config

        self.bert = None
        self.visual = None
        self.cross = None
        self.decoder = None

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
        elif isinstance(module, BertLayerNorm) \
                or isinstance(module, LayerNorm) \
                or isinstance(module, VisualBertLayerNorm) \
                or isinstance(module, CrossBertLayerNorm) \
                or isinstance(module, DecoderBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        raise ValueError("TODO: only the `bert` sub-model has this function; call it there.")

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                old_keys.append(key)
                new_keys.append(prefix + key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')

        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
            if len(error_msgs) > 0:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_name, pretrained_visual_model_name, pretrained_cross_model_name,
                        pretrained_decoder_model_name,
                        state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):

        task_config = None
        if "task_config" in kwargs.keys():
            task_config = kwargs["task_config"]
            if not hasattr(task_config, "local_rank"):
                task_config.__dict__["local_rank"] = 0
            elif task_config.local_rank == -1:
                task_config.local_rank = 0

        bert_config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=task_config)
        # the visual state dict is not used here
        visual_config, _ = PreTrainedVisualBertModel.get_config(pretrained_visual_model_name, cache_dir, state_dict=None, task_config=task_config)
        # the cross and decoder state dicts are not used here either
        cross_config, _ = PreTrainedCrossBertModel.get_config(pretrained_cross_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
        decoder_config, _ = PreTrainedDecoderBertModel.get_config(pretrained_decoder_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
        # Instantiate model.
        model = cls(bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs)

        assert model.bert is not None
        assert model.visual is not None
        # assert model.cross is not None
        # assert model.decoder is not None

        if state_dict is not None:
            model = cls.init_preweight(model, state_dict, task_config=task_config)

        return model


class CrossEn(nn.Module):
    def __init__(self,):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
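To see what `CrossEn` computes, here is a minimal sketch (toy numbers, assuming `CrossEn` as defined above): for a batch of N aligned text-video pairs, row i of the N×N similarity matrix is treated as logits over all videos, and the loss is the cross-entropy of the diagonal (matched) entry, averaged over rows:

```
import torch

loss_fct = CrossEn()
# 3 texts x 3 videos; the diagonal holds the matched pairs.
sim_matrix = torch.tensor([[5.0, 1.0, 0.0],
                           [0.5, 4.0, 0.2],
                           [0.1, 0.3, 6.0]])
loss = loss_fct(sim_matrix)   # mean of -log softmax over each row's diagonal entry
print(loss.item())            # small value, since the diagonal dominates each row
```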


class MILNCELoss(nn.Module):
    def __init__(self, batch_size=1, n_pair=1,):
        super(MILNCELoss, self).__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair

    def forward(self, sim_matrix):
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)  # video * text

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12

        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair // 2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=torch.uint8)).mean()
        return sim_loss
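The `np.kron` trick above expands the batch identity matrix into block form: each video contributes `n_pair` clips, and all clip/caption similarities within the same source video count as positives (the multiple-instance part of MIL-NCE). A standalone sketch of the mask it builds:

```
import numpy as np

batch_size, n_pair = 2, 3
mm_mask = np.kron(np.eye(batch_size), np.ones((n_pair, n_pair)))
print(mm_mask)
# [[1. 1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0. 0.]
#  [1. 1. 1. 0. 0. 0.]
#  [0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 1. 1. 1.]
#  [0. 0. 0. 1. 1. 1.]]
# Entry (i, j) == 1 marks clip/caption pairs from the same source video.
```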


class MaxMarginRankingLoss(nn.Module):
    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        if n_pair > 1 and batch_size > 1:
            alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
            mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
            self.mm_mask = mm_mask.float()

    def forward(self, x):
        d = torch.diag(x)
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
                     F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()
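A quick usage sketch of the ranking loss (a toy example, not from the repo): with a margin of 1.0, any off-diagonal similarity that comes within the margin of its row's or column's diagonal entry contributes to the hinge:

```
import torch

loss_fct = MaxMarginRankingLoss(margin=1.0)
sim = torch.tensor([[0.9, 0.2],
                    [0.1, 0.8]])
# e.g. entry (0, 1): relu(1.0 + 0.2 - 0.9) = 0.3 from the row hinge.
# Note the diagonal itself always contributes a constant 2*margin term.
print(loss_fct(sim).item())  # 1.3 for this toy matrix
```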


class NormalizeVideo(nn.Module):
    def __init__(self, task_config):
        super(NormalizeVideo, self).__init__()
        self.visual_norm2d = LayerNorm(task_config.video_dim)

    def forward(self, video):
        video = torch.as_tensor(video).float()
        video = video.view(-1, video.shape[-2], video.shape[-1])
        video = self.visual_norm2d(video)
        return video


def show_log(task_config, info):
    if task_config is None or task_config.local_rank == 0:
        logger.warning(info)


def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
    if hasattr(source_config, source_attr_name):
        if default_value is None or getattr(source_config, source_attr_name) != default_value:
            setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
            show_log(source_config, "Set {}.{}: {}.".format(target_name,
                                                            target_attr_name, getattr(target_config, target_attr_name)))
    return target_config


class VLBert(PreTrainedModel):
    def __init__(self, bert_config, visual_config, cross_config, decoder_config, task_config):
        super(VLBert, self).__init__(bert_config, visual_config, cross_config, decoder_config)
        self.task_config = task_config
        self.ignore_video_index = -1

        assert self.task_config.max_words <= bert_config.max_position_embeddings
        assert self.task_config.max_words <= decoder_config.max_target_embeddings
        assert self.task_config.max_frames <= visual_config.max_position_embeddings
        assert self.task_config.max_words + self.task_config.max_frames <= cross_config.max_position_embeddings

        self._choice_sim = True
        self._cross_model = False

        if hasattr(self.task_config, 'cross_model') and self.task_config.cross_model:
            self._choice_sim = False
            self._cross_model = self.task_config.cross_model

        self.train_sim_after_cross = False
        if self._choice_sim and hasattr(self.task_config, 'train_sim_after_cross') and self.task_config.train_sim_after_cross:
            self.train_sim_after_cross = True
            show_log(task_config, "Test retrieval after cross encoder.")

        self.with_sim_in_decoder = True
        if hasattr(self.task_config, 'without_sim_in_decoder') and self.task_config.without_sim_in_decoder:
            self.with_sim_in_decoder = False
            show_log(task_config, "Set no sim after cross.")

        self.pretrain_without_decoder = False
        if hasattr(self.task_config, 'pretrain_without_decoder') and self.task_config.pretrain_without_decoder:
            self.pretrain_without_decoder = True
            show_log(task_config, "Set pretrain without decoder.")

        show_log(task_config, "sim:{}, cross_model:{}".format(self._choice_sim, self._cross_model))

        # The name should be `bert` if a dynamic vocabulary is needed
        bert_config = update_attr("bert_config", bert_config, "num_hidden_layers",
                                  self.task_config, "text_num_hidden_layers")
        self.bert = BertModel(bert_config)
        bert_word_embeddings_weight = self.bert.embeddings.word_embeddings.weight
        bert_position_embeddings_weight = self.bert.embeddings.position_embeddings.weight

        visual_config = update_attr("visual_config", visual_config, "num_hidden_layers",
                                    self.task_config, "visual_num_hidden_layers")
        self.visual = VisualBertModel(visual_config)
        visual_word_embeddings_weight = self.visual.embeddings.word_embeddings.weight

        if self._choice_sim is False or self.train_sim_after_cross:
            self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
            self.cls_visual = VisualBertOnlyMLMHead(visual_config, visual_word_embeddings_weight)

            cross_config = update_attr("cross_config", cross_config, "num_hidden_layers",
                                       self.task_config, "cross_num_hidden_layers")
            self.cross = CrossBertModel(cross_config)

            if self.train_sim_after_cross is False:
                if self.task_config.do_pretrain and self.pretrain_without_decoder:
                    show_log(task_config, "Pretrain without decoder.")
                else:
                    decoder_config = update_attr("decoder_config", decoder_config, "num_decoder_layers",
                                                 self.task_config, "decoder_num_hidden_layers")
                    self.decoder = DecoderBertModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)

        self.normalize_video = NormalizeVideo(task_config)

        mILNCELoss = MILNCELoss(batch_size=task_config.batch_size // task_config.n_gpu, n_pair=task_config.n_pair, )
        maxMarginRankingLoss = MaxMarginRankingLoss(margin=task_config.margin,
                                                    negative_weighting=task_config.negative_weighting,
                                                    batch_size=task_config.batch_size // task_config.n_gpu,
                                                    n_pair=task_config.n_pair,
                                                    hard_negative_rate=task_config.hard_negative_rate, )
        if task_config.use_mil:
            self.loss_fct = mILNCELoss
            show_log(task_config, "Using MILNCE.")
        else:
            self.loss_fct = maxMarginRankingLoss
            show_log(task_config, "Using Ranking Loss.")

        if self._cross_model:
            self.loss_fct = CrossEn()
            show_log(task_config, "Use CE Loss.")

        if self._choice_sim is False or self.train_sim_after_cross:
            self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)

        self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
        self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)

        self.pretrain_with_joint_sim = False
        if hasattr(self.task_config, 'pretrain_with_joint_sim'):
            if self.task_config.pretrain_with_joint_sim:
                self.pretrain_with_joint_sim = True
                show_log(task_config, "Set pretrain with joint embedding.")
                if task_config.use_mil:
                    self._pretrain_sim_loss_fct = mILNCELoss
                    show_log(task_config, "Pretrain with joint embedding using MILNCE.")
                else:
                    self._pretrain_sim_loss_fct = maxMarginRankingLoss
                    show_log(task_config, "Pretrain with joint embedding using Ranking Loss.")

        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=None,
                pairs_masked_text=None, pairs_token_labels=None, masked_video=None, video_labels_index=None,
                input_caption_ids=None, decoder_mask=None, output_caption_ids=None):

        input_ids = input_ids.view(-1, input_ids.shape[-1])
        token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
        attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
        video_mask = video_mask.view(-1, video_mask.shape[-1])
        video = self.normalize_video(video)

        if input_caption_ids is not None:
            input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
            decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])

        sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids,
                                                                         attention_mask, video, video_mask, shaped=True)

        if self.training:
            loss = 0.
            if self._choice_sim:
                sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask, shaped=True)
                sim_loss = self.loss_fct(sim_matrix)
                loss = loss + sim_loss

            if self._cross_model and pairs_masked_text is not None and pairs_token_labels is not None:

                if self.task_config.do_pretrain:
                    pairs_masked_text = pairs_masked_text.view(-1, pairs_masked_text.shape[-1])
                    pairs_token_labels = pairs_token_labels.view(-1, pairs_token_labels.shape[-1])

                    masked_video = self.normalize_video(masked_video)
                    video_labels_index = video_labels_index.view(-1, video_labels_index.shape[-1])

                    sequence_output_alm, visual_output_alm = self.get_sequence_visual_output(pairs_masked_text, token_type_ids, attention_mask, masked_video, video_mask, shaped=True)

                    cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output_alm, visual_output_alm, attention_mask, video_mask)
                    sequence_cross_output, visual_cross_output = torch.split(cross_output, [attention_mask.size(-1), video_mask.size(-1)], dim=1)

                    alm_loss = self._calculate_mlm_loss(sequence_cross_output, pairs_token_labels)  # For mlm
                    loss = loss + alm_loss

                    nce_loss = self._calculate_mfm_loss(visual_cross_output, video, video_mask, video_labels_index)  # For mfm
                    loss += nce_loss

                    if self.pretrain_with_joint_sim:
                        sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask,
                                                                shaped=True, _pretrain_joint=True)
                        sim_loss_joint = self._pretrain_sim_loss_fct(sim_matrix)
                        loss = loss + sim_loss_joint

                if (input_caption_ids is not None) and \
                        ((self.task_config.do_pretrain and self.pretrain_without_decoder is False)
                         or (self.task_config.do_pretrain is False and self.task_config.task_type == "caption")):
                    if self.task_config.do_pretrain:
                        decoder_scores, res_tuples = self._get_decoder_score(sequence_output_alm, visual_output_alm,
                                                                             input_ids, attention_mask, video_mask,
                                                                             input_caption_ids, decoder_mask, shaped=True)
                    elif self.task_config.task_type == "caption":
                        decoder_scores, res_tuples = self._get_decoder_score(sequence_output, visual_output,
                                                                             input_ids, attention_mask, video_mask,
                                                                             input_caption_ids, decoder_mask, shaped=True)

                    output_caption_ids = output_caption_ids.view(-1, output_caption_ids.shape[-1])
                    decoder_loss = self.decoder_loss_fct(decoder_scores.view(-1, self.bert_config.vocab_size), output_caption_ids.view(-1))
                    loss = loss + decoder_loss

                if (self.with_sim_in_decoder and self.task_config.do_pretrain) or \
                        self.task_config.task_type == "retrieval":
                    if self.task_config.do_pretrain:
                        sim_matrix_text_visual = self.get_similarity_logits(sequence_output_alm, visual_output_alm, attention_mask,
                                                                            video_mask, shaped=True)
                    elif self.task_config.task_type == "retrieval":
                        sim_matrix_text_visual = self.get_similarity_logits(sequence_output, visual_output, attention_mask,
                                                                            video_mask, shaped=True)
                    sim_loss_text_visual = self.loss_fct(sim_matrix_text_visual)
                    loss = loss + sim_loss_text_visual

            return loss
        else:
            return None

    def _calculate_mlm_loss(self, sequence_output_alm, pairs_token_labels):
        alm_scores = self.cls(sequence_output_alm)
        alm_loss = self.alm_loss_fct(alm_scores.view(-1, self.bert_config.vocab_size), pairs_token_labels.view(-1))
        return alm_loss

    def _calculate_mfm_loss(self, visual_output_alm, video, video_mask, video_labels_index):
        afm_scores = self.cls_visual(visual_output_alm)
        afm_scores_tr = afm_scores.view(-1, afm_scores.shape[-1])

        video_tr = video.permute(2, 0, 1)
        video_tr = video_tr.view(video_tr.shape[0], -1)

        logits_matrix = torch.mm(afm_scores_tr, video_tr)
        video_mask_float = video_mask.to(dtype=torch.float)
        mask_matrix = torch.mm(video_mask_float.view(-1, 1), video_mask_float.view(1, -1))
        masked_logits = logits_matrix + (1. - mask_matrix) * -1e8

        logpt = F.log_softmax(masked_logits, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt

        video_labels_index_mask = (video_labels_index != self.ignore_video_index)
        nce_loss = nce_loss.masked_select(video_labels_index_mask.view(-1))
        nce_loss = nce_loss.mean()
        return nce_loss

    # For inference
    def get_sequence_visual_output(self, input_ids, token_type_ids, attention_mask, video, video_mask, shaped=False):
        if shaped is False:
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            token_type_ids = token_type_ids.view(-1, token_type_ids.shape[-1])
            attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])
            video = self.normalize_video(video)

        encoded_layers, _ = self.bert("bi", input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
        sequence_output = encoded_layers[-1]

        visual_layers, _ = self.visual("bi", video, video_mask, output_all_encoded_layers=True)
        visual_output = visual_layers[-1]

        # Cannot add any cross-modal code here,
        # because the sequence and visual features must stay independent at this stage.

        return sequence_output, visual_output

    def _mean_pooling_for_similarity(self, sequence_output, visual_output, attention_mask, video_mask,):
        attention_mask_un = attention_mask.to(dtype=torch.float).unsqueeze(-1)
        attention_mask_un[:, 0, :] = 0.
        sequence_output = sequence_output * attention_mask_un
        text_out = torch.sum(sequence_output, dim=1) / torch.sum(attention_mask_un, dim=1, dtype=torch.float)

        video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
        visual_output = visual_output * video_mask_un
        video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
        video_mask_un_sum[video_mask_un_sum == 0.] = 1.
        video_out = torch.sum(visual_output, dim=1) / video_mask_un_sum

        return text_out, video_out

    def _cross_similarity(self, sequence_output, visual_output, attention_mask, video_mask):
        b_text, s_text, h_text = sequence_output.size()
        b_visual, s_visual, h_visual = visual_output.size()

        retrieve_logits_list = []
        step_size = 5

        split_size = [step_size] * (b_text // step_size)
        release_size = b_text - sum(split_size)
        if release_size > 0:
            split_size += [release_size]

        sequence_output_splits = torch.split(sequence_output, split_size, dim=0)
        attention_mask_splits = torch.split(attention_mask, split_size, dim=0)
        for i in range(len(split_size)):
            sequence_output_row = sequence_output_splits[i]
            attention_mask_row = attention_mask_splits[i]
            sequence_output_l = sequence_output_row.unsqueeze(1).repeat(1, b_visual, 1, 1)
            sequence_output_l = sequence_output_l.view(-1, s_text, h_text)
            attention_mask_l = attention_mask_row.unsqueeze(1).repeat(1, b_visual, 1)
            attention_mask_l = attention_mask_l.view(-1, s_text)

            step_truth = sequence_output_row.size(0)
            visual_output_r = visual_output.unsqueeze(0).repeat(step_truth, 1, 1, 1)
            visual_output_r = visual_output_r.view(-1, s_visual, h_visual)
            video_mask_r = video_mask.unsqueeze(0).repeat(step_truth, 1, 1)
            video_mask_r = video_mask_r.view(-1, s_visual)

            cross_output, pooled_output, concat_mask = \
                self._get_cross_output(sequence_output_l, visual_output_r, attention_mask_l, video_mask_r)
            retrieve_logits_row = self.similarity_dense(pooled_output).squeeze(-1).view(step_truth, b_visual)

            retrieve_logits_list.append(retrieve_logits_row)
        retrieve_logits = torch.cat(retrieve_logits_list, dim=0)
        return retrieve_logits

    # For inference
    def get_similarity_logits(self, sequence_output, visual_output, attention_mask, video_mask, shaped=False, _pretrain_joint=False):
        if shaped is False:
            attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

        if self._cross_model and _pretrain_joint is False:
            retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
        else:
            if self.train_sim_after_cross:
                retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
            else:
                text_out, video_out = self._mean_pooling_for_similarity(sequence_output, visual_output, attention_mask, video_mask)
                # Do a cosine similarity
                if self.task_config.use_mil is False:
                    text_out = F.normalize(text_out, dim=-1)
                    video_out = F.normalize(video_out, dim=-1)
                retrieve_logits = torch.matmul(text_out, video_out.t())

        return retrieve_logits

    def _get_cross_output(self, sequence_output, visual_output, attention_mask, video_mask):
        # Generate one pair (x, y) ===>
        concat_features = torch.cat((sequence_output, visual_output), dim=1)  # concatenate tokens and frames
        concat_mask = torch.cat((attention_mask, video_mask), dim=1)
        text_type_ = torch.zeros_like(attention_mask)
        video_type_ = torch.ones_like(video_mask)
        concat_type = torch.cat((text_type_, video_type_), dim=1)
        # <===
        cross_layers, pooled_output = self.cross("bi", concat_features, concat_type, concat_mask, output_all_encoded_layers=True)
        cross_output = cross_layers[-1]

        return cross_output, pooled_output, concat_mask

    def _get_decoder_score(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=False):

        if shaped is False:
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

            input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
            decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])

        res_tuples = ()
        cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output, visual_output, attention_mask, video_mask)
        decoder_scores = self.decoder(input_caption_ids, input_ids, encoder_outs=cross_output, answer_mask=decoder_mask, encoder_mask=concat_mask)

        return decoder_scores, res_tuples

    def decoder_caption(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask,
                        shaped=False, get_logits=False):
        """
        :param get_logits: used in beam search
        :return:
        """
        if shaped is False:
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

            input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
            decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])

        assert self.pretrain_without_decoder is False
        decoder_scores, _ = self._get_decoder_score(sequence_output, visual_output,
                                                    input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=True)

        if get_logits:
            return decoder_scores

        _, decoder_scores_result = torch.max(decoder_scores, -1)

        return decoder_scores_result

@@ -0,0 +1,180 @@

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging

logger = logging.getLogger(__name__)


def warmup_cosine(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    # x is a plain float here, so use math.cos (torch.cos would fail on a float)
    return 0.5 * (1.0 + math.cos(math.pi * x))


def warmup_constant(x, warmup=0.002):
    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
        Learning rate is 1. afterwards. """
    if x < warmup:
        return x / warmup
    return 1.0


def warmup_linear(x, warmup=0.002):
    """ Specifies a triangular learning rate schedule where the peak is reached at the `warmup`*`t_total`-th (as provided to BertAdam) training step.
        After the `t_total`-th training step, the learning rate is zero. """
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)


SCHEDULES = {
    'warmup_cosine': warmup_cosine,
    'warmup_constant': warmup_constant,
    'warmup_linear': warmup_linear,
}
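To make the schedules concrete, here is a small sketch (toy numbers, not from the repo) of the learning-rate multiplier that `warmup_linear` produces at a few values of training progress `x = step / t_total`, with a 10% warmup:

```
for x in [0.0, 0.05, 0.1, 0.5, 1.0]:
    print(x, warmup_linear(x, warmup=0.1))
# 0.0  -> 0.0     (start of warmup)
# 0.05 -> 0.5     (halfway through warmup)
# 0.1  -> 1.0     (peak at warmup * t_total)
# 0.5  -> ~0.556  (linear decay toward zero)
# 1.0  -> 0.0     (fully decayed)
```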


class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.
    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate. Default: -1
        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
        b1: Adam's b1. Default: 0.9
        b2: Adam's b2. Default: 0.999
        e: Adam's epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
                 max_grad_norm=1.0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    def get_lr(self):
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    lr_scheduled = group['lr'] * schedule_fct(state['step'] / group['t_total'], group['warmup'])
                else:
                    lr_scheduled = group['lr']
                lr.append(lr_scheduled)
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        warned_for_t_total = False

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(1 - beta1, grad)
                next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                if group['t_total'] != -1:
                    schedule_fct = SCHEDULES[group['schedule']]
                    progress = state['step'] / group['t_total']
                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
                    # warning for exceeding t_total (only active with warmup_linear)
                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
                        logger.warning(
                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
                        warned_for_t_total = True
                    # end warning
                else:
                    lr_scheduled = group['lr']

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
                # No bias correction
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']

        return loss
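A typical usage sketch (the `model` variable and step count are hypothetical; the parameter grouping mirrors standard BERT practice of exempting biases and LayerNorm parameters from weight decay):

```
param_optimizer = list(model.named_parameters())          # `model` is assumed to exist
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = BertAdam(optimizer_grouped_parameters, lr=3e-5,
                     warmup=0.1, t_total=10000, schedule='warmup_linear')
```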

@@ -0,0 +1,408 @@

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import unicodedata
import os
import sys
import logging

from .file_utils import cached_path

logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
    'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
    'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
    'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
# Keys must match the model names above so the max_len lookup in from_pretrained succeeds.
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'bert-base-uncased': 512,
    'bert-large-uncased': 512,
    'bert-base-cased': 512,
    'bert-large-cased': 512,
    'bert-base-multilingual-uncased': 512,
    'bert-base-multilingual-cased': 512,
    'bert-base-chinese': 512,
}
VOCAB_NAME = 'vocab.txt'


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        while True:
            token = reader.readline()
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer(object):
    """Runs end-to-end tokenization: punctuation splitting + WordPiece."""

    def __init__(self, vocab_file, do_lower_case=True, max_len=None, never_split=("[UNK]", "[SEP]", "[MASK]", "[CLS]")):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        """Converts a sequence of tokens into ids using the vocab."""
        ids = []
        for token in tokens:
            if token not in self.vocab:
                ids.append(self.vocab["[UNK]"])
                logger.error("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
            else:
                ids.append(self.vocab[token])
        if len(ids) > self.max_len:
            raise ValueError(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this BERT model ({} > {}). Running this"
                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
            )
        return ids

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids into tokens using the vocab."""
        tokens = []
        for i in ids:
            tokens.append(self.ids_to_tokens[i])
        return tokens

    @classmethod
    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        vocab_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(vocab_file) is False:
            if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
                vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
            else:
                vocab_file = pretrained_model_name
        if os.path.isdir(vocab_file):
            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
        # redirect to the cache, if necessary
        logger.info("resolving vocabulary file %s", vocab_file)
        try:
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found. "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    vocab_file))
            return None
        if resolved_vocab_file == vocab_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
            # than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        kwargs['never_split'] = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")

        # Instantiate tokenizer.
        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)

        return tokenizer

    def add_tokens(self, new_tokens, model):
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
        vocabulary, they are added to it with indices starting from the length of the current vocabulary.
        Args:
            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
        Returns:
            Number of tokens added to the vocabulary.
        Examples::
            # Let's see how to increase the vocabulary of the Bert model and tokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertModel.from_pretrained('bert-base-uncased')
            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects the full size of the new vocabulary, i.e. the length of the tokenizer.
        """

        to_add_tokens = []
        for token in new_tokens:
            assert isinstance(token, str)
            to_add_tokens.append(token)
            # logger.info("Adding %s to the vocabulary", token)

        vocab = collections.OrderedDict()
        for token in self.vocab.keys():
            vocab[token] = self.vocab[token]
        for token in to_add_tokens:
            vocab[token] = len(vocab)
        self.vocab = self.wordpiece_tokenizer.vocab = vocab
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])

        model.resize_token_embeddings(new_num_tokens=len(vocab))
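A short end-to-end usage sketch (assuming the `vocab.txt` from the Preliminary step has been placed under `pytorch_pretrained_bert/bert-base-uncased/` as the README describes; the example sentence and its split are illustrative):

```
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
tokens = tokenizer.tokenize("Slice the tomatoes thinly")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)                                  # WordPiece tokens, e.g. ['slice', 'the', 'tomatoes', ...]
print(tokenizer.convert_ids_to_tokens(ids))    # round-trips back to the same tokens
```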


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        if text in self.never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
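The greedy longest-match-first behavior is easy to verify with a toy vocabulary (a standalone sketch, independent of any real `vocab.txt`):

```
toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "runn": 3, "##ing": 4, "[UNK]": 5}
wp = WordpieceTokenizer(vocab=toy_vocab)
print(wp.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wp.tokenize("running"))    # ['runn', '##ing']
print(wp.tokenize("xyz"))        # ['[UNK]'] - no prefix of "xyz" is in the vocabulary
```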


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

@@ -0,0 +1,12 @@

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 1,
  "vocab_size": 1024
}
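In this visual config, `vocab_size` is not a token count: as the `VisualBertEmbeddings` module below shows, it is the dimensionality of the per-frame video features that a linear layer projects to `hidden_size`. A loading sketch (assuming the file above is saved as `visual_bert_config.json`):

```
config = VisualBertConfig.from_json_file('visual_bert_config.json')
print(config.vocab_size, config.hidden_size)  # 1024 768 - feature dim in, hidden size out
```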

@@ -0,0 +1,662 @@

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'visual-bert-base': "[TODO]",
    'visual-bert-large': "[TODO]",
}

CONFIG_NAME = 'visual_bert_config.json'
WEIGHTS_NAME = 'visual_pytorch_model.bin'


def gelu(x):
    """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
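A quick numerical sketch of the docstring's note: the exact erf-based GELU above and OpenAI GPT's tanh approximation agree closely over typical activation ranges (a toy check, assuming `gelu` as defined above):

```
import math
import torch

x = torch.linspace(-4.0, 4.0, steps=9)
exact = gelu(x)
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print((exact - approx).abs().max())  # small - well under 1e-2
```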


class VisualBertConfig(object):
    """Configuration class to store the configuration of a `VisualBertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file=4096,
                 hidden_size=768,
                 num_hidden_layers=3,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 initializer_range=0.02):
        """Constructs VisualBertConfig.

        Args:
            vocab_size_or_config_json_file: Dimensionality of the input visual features
                (fed to a linear embedding layer), or the path to a JSON config file.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `VisualBertConfig` from a Python dictionary of parameters."""
        config = VisualBertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `VisualBertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
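# A minimal round-trip sketch for the config above (the 4096 matches the
# default visual feature dimensionality; purely illustrative):
#
#   config = VisualBertConfig(vocab_size_or_config_json_file=4096,
#                             hidden_size=768, num_hidden_layers=6)
#   restored = VisualBertConfig.from_dict(json.loads(config.to_json_string()))
#   assert restored.num_hidden_layers == config.num_hidden_layers
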
try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as VisualBertLayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")

    class VisualBertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root)."""
            super(VisualBertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias

class VisualBertEmbeddings(nn.Module):
    """Construct the embeddings from projected visual features and position embeddings."""
    def __init__(self, config):
        super(VisualBertEmbeddings, self).__init__()

        self.word_embeddings = nn.Linear(config.vocab_size, config.hidden_size)
        # self.transform_act_fn = ACT2FN[config.hidden_act] \
        #     if isinstance(config.hidden_act, str) else config.hidden_act

        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_embeddings):
        seq_length = input_embeddings.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_embeddings.device)
        position_ids = position_ids.unsqueeze(0).expand(input_embeddings.size(0), -1)

        words_embeddings = self.word_embeddings(input_embeddings)
        # words_embeddings = self.transform_act_fn(words_embeddings)

        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class VisualBertSelfAttention(nn.Module):
    def __init__(self, config):
        super(VisualBertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the VisualBertModel forward() function).
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer

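# Shape sketch for the self-attention above (B = batch, L = sequence length,
# H = heads, D = head size, with H * D == hidden_size):
#   hidden_states [B, L, H*D] -> Q, K, V each reshaped to [B, H, L, D]
#   scores = Q @ K^T / sqrt(D): [B, H, L, L]; add mask; softmax over last dim
#   context = probs @ V: [B, H, L, D], re-merged to [B, L, H*D]
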
class VisualBertSelfOutput(nn.Module):
    def __init__(self, config):
        super(VisualBertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class VisualBertAttention(nn.Module):
    def __init__(self, config):
        super(VisualBertAttention, self).__init__()
        self.self = VisualBertSelfAttention(config)
        self.output = VisualBertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class VisualBertIntermediate(nn.Module):
    def __init__(self, config):
        super(VisualBertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class VisualBertOutput(nn.Module):
    def __init__(self, config):
        super(VisualBertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class VisualBertLayer(nn.Module):
    def __init__(self, config):
        super(VisualBertLayer, self).__init__()
        self.attention = VisualBertAttention(config)
        self.intermediate = VisualBertIntermediate(config)
        self.output = VisualBertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class VisualBertEncoder(nn.Module):
    def __init__(self, config):
        super(VisualBertEncoder, self).__init__()
        layer = VisualBertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        all_encoder_layers = []
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

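# Note on the encoder's return value: with output_all_encoded_layers=True it
# is a list with one [B, L, hidden] tensor per layer; with False it is a
# single-element list holding only the last layer's hidden states.
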
class VisualBertPooler(nn.Module):
    def __init__(self, config):
        super(VisualBertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class VisualBertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(VisualBertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class VisualBertLMPredictionHead(nn.Module):
    """
    Note: be careful with `visual_bert_model_embedding_weights`. It is the
    weight of the Linear projection in VisualBertEmbeddings (shape
    [hidden_size, feature_dim]), not a token Embedding as in BERT, so it is
    applied here without transposition.
    """
    def __init__(self, config, visual_bert_model_embedding_weights):
        super(VisualBertLMPredictionHead, self).__init__()
        self.transform = VisualBertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.weight = visual_bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(visual_bert_model_embedding_weights.size(1)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = hidden_states.matmul(self.weight) + self.bias
        return hidden_states


class VisualBertOnlyMLMHead(nn.Module):
    def __init__(self, config, visual_bert_model_embedding_weights):
        super(VisualBertOnlyMLMHead, self).__init__()
        self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class VisualBertOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super(VisualBertOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class VisualBertPreTrainingHeads(nn.Module):
    def __init__(self, config, visual_bert_model_embedding_weights):
        super(VisualBertPreTrainingHeads, self).__init__()
        self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

class PreTrainedVisualBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedVisualBertModel, self).__init__()
        if not isinstance(config, VisualBertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `VisualBertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_visual_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, VisualBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, state_dict, task_config=None):
        '''
        Factored out of `from_pretrained`: resolve the config and state dict.
        '''
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = VisualBertConfig.from_json_file(config_file)
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights file doesn't exist: {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

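    # Resolution order in get_config above: a file or directory sitting next
    # to this module wins, then a known name in PRETRAINED_MODEL_ARCHIVE_MAP,
    # and otherwise the string is treated as a path/URL handed to cached_path().
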
    @classmethod
    def init_preweight(cls, model, state_dict):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'visual_bert') else 'visual_bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}"
                        .format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}"
                        .format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
        if len(error_msgs) > 0:
            logger.error("Weights from pretrained model cause errors in {}: {}"
                         .format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

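    # Prefix note for init_preweight: when the target model lacks a
    # `visual_bert` attribute (i.e., it *is* the bare encoder), its parameters
    # are looked up under the `visual_bert.` prefix instead, so wrapper models
    # and the bare encoder can consume the same checkpoint format.
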
    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedVisualBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `visual-bert-base`
                    . `visual-bert-large`
                - a path or url to a pretrained model archive containing:
                    . `visual_bert_config.json` a configuration file for the model
                    . `visual_pytorch_model.bin` a PyTorch dump of a VisualBertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific VisualBert class
                (ex: num_classes for VisualBertForSequenceClassification)
        """
        config, state_dict = PreTrainedVisualBertModel.get_config(pretrained_model_name, cache_dir, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedVisualBertModel.init_preweight(model, state_dict)

        return model


class VisualBertModel(PreTrainedVisualBertModel):
    """VisualBert model ("Bidirectional Embedding Representations from a Transformer").

    Params:
        config: a VisualBertConfig class instance with the configuration to build a new model

    Inputs:
        `type`: a str, indicates which masking will be used in the attention, chosen from [`bi`, `seq`, `gen`]
        `video`: a torch.FloatTensor of shape [batch_size, sequence_length, feature_dim]
            with the visual features of the video clip
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sequences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by the `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for VisualBert-base, 24 for VisualBert-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first position of the
            input (`CLS`) to train on the Next-Sentence task (see Bert's paper).

    Example usage:
    ```python
    # A batch of 2 clips, 3 frames each, with 4096-d visual features
    video = torch.randn(2, 3, 4096)
    video_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

    config = modeling.VisualBertConfig(vocab_size_or_config_json_file=4096, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.VisualBertModel(config=config)
    all_encoder_layers, pooled_output = model('bi', video, video_mask)
    ```
    """
    def __init__(self, config):
        super(VisualBertModel, self).__init__(config)
        self.embeddings = VisualBertEmbeddings(config)
        self.encoder = VisualBertEncoder(config)
        self.pooler = VisualBertPooler(config)
        self.apply(self.init_visual_bert_weights)

    def forward(self, type, video, attention_mask=None, output_all_encoded_layers=True,
                attention_sentenceA_mask=None, attention_sentenceB_mask=None):

        if attention_mask is None:
            # Create the default mask on the same device as the input so this
            # path also works when `video` lives on GPU.
            attention_mask = torch.ones(video.size(0), video.size(1), device=video.device)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(video)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
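
# Mask arithmetic sketch for the forward pass above: a 2D mask row [1, 1, 0]
# broadcasts to shape [1, 1, 3] per batch item and becomes
# (1.0 - mask) * -10000.0 = [0, 0, -10000], so after adding it to the raw
# scores, softmax assigns ~zero attention to the padded position.
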
@@ -0,0 +1,529 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)

# from torch.utils.data import DataLoader
from data_prefetch_unitl import DataLoaderX as DataLoader  # Enhanced Loader

import numpy as np
import random
import os
from youcook_nopair_dataloader import Youcook_NoPair_DataLoader
from metrics import compute_metrics, print_computed_metrics
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from util import parallel_apply

def get_logger(filename=None):
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        handler = logging.FileHandler(filename)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logging.getLogger().addHandler(handler)
    return logger

global logger

def get_args(description='Youtube-Text-Video'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
    parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=20, help='')
    parser.add_argument('--max_frames', type=int, default=100, help='')
    parser.add_argument('--min_words', type=int, default=0, help='')
    parser.add_argument('--feature_framerate', type=int, default=1, help='')
    parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
    parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')

    parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
    parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
    parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
    parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
    parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption.pickle', help='youcookii caption pickle file path')
    parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii 2D feature path')
    parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii 3D feature path')

    parser.add_argument('--pad_token', type=str, default='[PAD]', help='')

    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
                        help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
    parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
                        help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
    parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
                        help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")

    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--task_type", default="retrieval", type=str, help="Task (`retrieval`) to finetune.")
    parser.add_argument("--datatype", default="youcook", type=str, help="Dataset (`youcook`) to finetune on.")
    parser.add_argument('--coef_lr', type=float, default=0.1, help='lr coefficient for the bert branch.')
    parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use sampled MIL; takes priority over use_mil.")

    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Number of hidden layers in the text encoder.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the visual encoder.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Number of hidden layers in the cross encoder.")
    parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the decoder.")

    parser.add_argument('--train_sim_after_cross', action='store_true', help="Test retrieval after cross encoder.")

    args = parser.parse_args()

    if args.sampled_use_mil:  # sample from each video; takes priority over use_mil.
        args.use_mil = True

    # Check parameters
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_pretrain and not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_pretrain`, `do_train` or `do_eval` must be True.")

    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)

    return args

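# Note on the batch_size rescaling at the end of get_args: the per-step batch
# is divided by gradient_accumulation_steps, so the effective batch per
# optimizer update stays at the originally requested --batch_size.
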
def set_seed_logger(args):
    global logger
    # predefine random initial seeds
    random.seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))
    logger.info("Effective parameters:")
    for key in sorted(args.__dict__):
        logger.info("  <<< {}: {}".format(key, args.__dict__[key]))

def init_device(args):
    global logger
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    args.n_gpu = n_gpu

    if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
        raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
            args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))

    return device, n_gpu

def init_model(args, device, n_gpu):

    if args.init_model:
        model_state_dict = torch.load(args.init_model, map_location='cpu')
    else:
        model_state_dict = None

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
                                   cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    model.to(device)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, 'einsum')
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    return model

def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=1.0):

    if hasattr(model, 'module'):
        model = model.module

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # NOTE: the two names below are inverted relative to their filters:
    # `no_decay_param_tp` holds parameters *not* matching `no_decay` (they DO
    # get weight decay), while `decay_param_tp` holds the bias/LayerNorm
    # parameters that get none. The group settings below compensate, so the
    # net behavior is correct.
    no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
    decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]

    no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n]
    no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n]

    decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n]
    decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n]

    optimizer_grouped_parameters = [
        {'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01},
        {'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
    ]
    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
                         schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
                         max_grad_norm=1.0)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    return optimizer, scheduler, model

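# Grouping sketch for prep_optimizer: four groups = {with, without weight
# decay} x {bert branch, freshly initialized modules}; the bert branch runs
# at lr * coef_lr (0.1 by default) so pretrained weights move more slowly
# than the new layers.
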
def dataloader_youcook_train(args, tokenizer):
    logger.info('Loading captions: {}'.format(args.youcook_caption_path))
    caption = pickle.load(open(args.youcook_caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    args.n_pair = 1
    youcook_dataset = Youcook_NoPair_DataLoader(
        csv=args.youcook_train_csv,
        features_path=args.youcook_features_path_2D,
        features_path_3D=args.youcook_features_path_3D,
        caption=caption,
        min_time=args.min_time,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=args.n_pair,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

    dataloader = DataLoader(
        youcook_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_thread_reader,
        shuffle=True,
        drop_last=True,
    )

    return dataloader, len(youcook_dataset)

def dataloader_youcook_test(args, tokenizer):
    logger.info('Loading captions: {}'.format(args.youcook_caption_path))
    caption = pickle.load(open(args.youcook_caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    youcook_testset = Youcook_NoPair_DataLoader(
        csv=args.youcook_val_csv,
        features_path=args.youcook_features_path_2D,
        features_path_3D=args.youcook_features_path_3D,
        caption=caption,
        min_time=args.min_time,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=-1,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

    test_sampler = SequentialSampler(youcook_testset)
    dataloader_youcook = DataLoader(
        youcook_testset,
        sampler=test_sampler,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        pin_memory=False,
    )
    logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))

    return dataloader_youcook, len(youcook_testset)

def save_model(epoch, args, model, type_name=""):
    # Only save the model itself
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(
        args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name == "" else type_name + ".", epoch))
    torch.save(model_to_save.state_dict(), output_model_file)
    logger.info("Model saved to %s", output_model_file)
    return output_model_file

def load_model(epoch, args, n_gpu, device, model_file=None):
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
    if os.path.exists(model_file):
        model_state_dict = torch.load(model_file, map_location='cpu')
        logger.info("Model loaded from %s", model_file)
        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
                                       cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)
    else:
        model = None
    return model

def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step):
    global logger
    torch.cuda.empty_cache()
    model.train()
    log_step = 100
    start_time = time.time()
    total_loss = 0

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            # multi-gpu does scattering itself
            batch = tuple(t.to(device) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask, \
            pairs_masked_text, pairs_token_labels, masked_video, video_labels_index = batch
        loss = model(input_ids, segment_ids, input_mask, video, video_mask,
                     pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels,
                     masked_video=masked_video, video_labels_index=video_labels_index)

        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step()  # Update learning rate schedule

            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            if global_step % log_step == 0:
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
                            len(train_dataloader), "-".join([str('%.6f' % itm) for itm in sorted(list(set(optimizer.get_lr())))]), loss,
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                start_time = time.time()

    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step

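# Accumulation sketch for train_epoch: the loss is scaled by
# 1/gradient_accumulation_steps and backpropagated on every batch, but
# clipping and optimizer.step() only fire every gradient_accumulation_steps
# batches, emulating a proportionally larger batch size.
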
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
    sim_matrix = []
    for idx1, b1 in enumerate(batch_list_t):
        input_ids, input_mask, segment_ids, _, _, _, _, _, _ = b1
        sequence_output = batch_sequence_output_list[idx1]
        each_row = []
        for idx2, b2 in enumerate(batch_list_v):
            _, _, _, video, video_mask, _, _, _, _ = b2
            visual_output = batch_visual_output_list[idx2]
            b1b2_logits = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask)
            b1b2_logits = b1b2_logits.cpu().detach().numpy()
            each_row.append(b1b2_logits)
        each_row = np.concatenate(tuple(each_row), axis=-1)
        sim_matrix.append(each_row)
    return sim_matrix

def eval_epoch(args, model, test_dataloader, device, n_gpu):

    if hasattr(model, 'module'):
        model = model.module.to(device)
    else:
        model = model.to(device)

    model.eval()
    with torch.no_grad():
        batch_list = []
        batch_sequence_output_list, batch_visual_output_list = [], []
        for bid, batch in enumerate(test_dataloader):
            batch = tuple(t.to(device) for t in batch)

            input_ids, input_mask, segment_ids, video, video_mask, _, _, _, _ = batch
            sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)

            batch_sequence_output_list.append(sequence_output)
            batch_visual_output_list.append(visual_output)
            batch_list.append(batch)

            print("{}/{}\r".format(bid, len(test_dataloader)), end="")

        start_time = time.time()
        if n_gpu > 1:
            device_ids = list(range(n_gpu))
            batch_list_t_splits = []
            batch_list_v_splits = []
            batch_t_output_splits = []
            batch_v_output_splits = []
            batch_len = len(batch_list)
            split_len = (batch_len + n_gpu - 1) // n_gpu
            for dev_id in device_ids:
                s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
                if dev_id == 0:
                    batch_list_t_splits.append(batch_list[s_:e_])
                    batch_list_v_splits.append(batch_list)

                    batch_t_output_splits.append(batch_sequence_output_list[s_:e_])
                    batch_v_output_splits.append(batch_visual_output_list)
                else:
                    devc = torch.device('cuda:{}'.format(str(dev_id)))
                    devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list[s_:e_]]
                    batch_list_t_splits.append(devc_batch_list)
                    devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list]
                    batch_list_v_splits.append(devc_batch_list)

                    devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
                    batch_t_output_splits.append(devc_batch_list)
                    devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
                    batch_v_output_splits.append(devc_batch_list)

            parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
                                      batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
            parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
            sim_matrix = []
            for idx in range(len(parallel_outputs)):
                sim_matrix += parallel_outputs[idx]
        else:
            sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list)
        logger.info("%f" % (time.time() - start_time))

    sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
    metrics = compute_metrics(sim_matrix)
    logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0])))
    logger.info('\t>>>  R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.
                format(metrics['R1'], metrics['R5'], metrics['R10'], metrics['MR']))

    R1 = metrics['R1']
    return R1

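# Multi-GPU evaluation sketch: text batches are sharded across devices while
# every device receives all video batches, so each GPU computes a horizontal
# stripe of the text-video similarity matrix; the stripes are concatenated on
# axis 0 before computing R@1/R@5/R@10.
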
DATALOADER_DICT = {}
DATALOADER_DICT["youcook"] = {"train": dataloader_youcook_train, "val": dataloader_youcook_test}

def main():
    global logger
    args = get_args()
    set_seed_logger(args)
    device, n_gpu = init_device(args)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    if args.do_train:
        args.task_type = "retrieval"
    model = init_model(args, device, n_gpu)

    datatype = args.datatype
    assert datatype in DATALOADER_DICT
    test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
    logger.info("***** Running test *****")
    logger.info("  Num examples = %d", test_length)
    logger.info("  Batch size = %d", args.batch_size_val)
    logger.info("  Num steps = %d", len(test_dataloader))

    if args.do_pretrain:
        raise NotImplementedError("Pretraining is not supported in this script.")
    elif args.do_train:

        train_dataloader, train_length = DATALOADER_DICT[datatype]["train"](args, tokenizer)
        num_train_optimization_steps = ((len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        // args.gradient_accumulation_steps) * args.epochs

        coef_lr = 0.1
        if args.init_model:
            coef_lr = 1.0
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=coef_lr)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", train_length)
        logger.info("  Batch size = %d", args.batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        best_score = 0.00001
        best_output_model_file = None
        global_step = 0
        for epoch in range(args.epochs):
            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                               scheduler, global_step)
            logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
            output_model_file = save_model(epoch, args, model, type_name="")

            R1 = eval_epoch(args, model, test_dataloader, device, n_gpu)
            if best_score <= R1:
                best_score = R1
                best_output_model_file = output_model_file
            logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))
        model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
        eval_epoch(args, model, test_dataloader, device, n_gpu)
    elif args.do_eval:
        eval_epoch(args, model, test_dataloader, device, n_gpu)

if __name__ == "__main__":
    main()
@@ -0,0 +1,850 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)

from data_prefetch_unitl import DataLoaderX as DataLoader  # Enhanced Loader

import numpy as np
import random
import os
from collections import OrderedDict
from youcook_transcript_nopair_dataloader import Youcook_Transcript_NoPair_DataLoader
from youtube_transcript_dataloader import Youtube_Transcript_DataLoader
from rouge import Rouge
from nlgeval import compute_metrics, NLGEval
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from pytorch_pretrained_bert.beam import Beam

torch.distributed.init_process_group(backend="nccl")

global amp_pck_loaded_
try:
    from apex import amp
    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    amp.register_half_function(torch, 'einsum')
    amp_pck_loaded_ = True
except ImportError:
    amp_pck_loaded_ = False

def get_logger(filename=None):
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        handler = logging.FileHandler(filename)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logging.getLogger().addHandler(handler)
    return logger

global logger

def get_args(description='Youtube-Text-Video'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--train_csv', type=str, default='data/HowTo100M_v1.csv', help='train csv')
    parser.add_argument('--features_path_2D', type=str, default='feature_2d', help='feature path for 2D features')
    parser.add_argument('--features_path_3D', type=str, default='feature_3d', help='feature path for 3D features')
    parser.add_argument('--caption_path', type=str, default='data/caption.pickle', help='caption pickle file path')

    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
    parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=20, help='')
    parser.add_argument('--max_frames', type=int, default=100, help='')
    parser.add_argument('--min_words', type=int, default=0, help='')
    parser.add_argument('--feature_framerate', type=int, default=1, help='')
    parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
    parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')

    parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
    parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
    parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
    parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
    parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption_transcript.pickle', help='youcookii caption and transcription pickle file path')
    parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii feature path for 2D features')
    parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii feature path for 3D features')

    parser.add_argument('--pad_token', type=str, default='[PAD]', help='')

    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
                        help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
    parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
                        help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
    parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
                        help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed during execution.")

    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--task_type", default="caption", type=str, help="Point the task `retrieval` or `caption` to finetune.")
    parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")

    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
    parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use MIL; takes priority over use_mil.")

    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Number of hidden layers in the text encoder.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the visual encoder.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Number of hidden layers in the cross encoder.")
    parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Number of hidden layers in the decoder.")

    parser.add_argument('--cross_model', action='store_true', help="Whether to train with the decoder.")
    parser.add_argument('--pretrain_with_joint_sim', action='store_true', help="Whether to use the joint embedding when pretraining.")
    parser.add_argument('--pretrain_enhance_vmodal', action='store_true', help="Enhance visual and other modalities when pretraining.")
    parser.add_argument('--without_sim_in_decoder', action='store_true', help="Whether to align in the decoder when training.")
    parser.add_argument('--pretrain_without_decoder', action='store_true', help="Whether to ignore the decoder when pretraining.")

    parser.add_argument("--load_checkpoint", action="store_true")
    parser.add_argument("--checkpoint_model", default="pytorch_model.bin.checkpoint", type=str, required=False,
                        help="Save the last model as a checkpoint.")

    args = parser.parse_args()

    if args.sampled_use_mil:  # sample from each video; takes priority over use_mil.
        args.use_mil = True
    if args.do_pretrain is False:
        args.pretrain_without_decoder = False

    global amp_pck_loaded_
    if args.fp16:
        if not amp_pck_loaded_:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    # Check parameters
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_pretrain and not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_pretrain`, `do_train`, or `do_eval` must be True.")

    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)

    args.checkpoint_model = '{}_{}_{}_{}.checkpoint'.format(args.checkpoint_model, args.bert_model, args.max_words, args.max_frames)

    return args

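Note that `get_args` divides the requested `batch_size` by `gradient_accumulation_steps`, so the per-step batch times the number of accumulation steps recovers the requested effective batch. A quick sanity check of that bookkeeping (values illustrative, not part of this commit):

```python
requested_batch_size = 256
gradient_accumulation_steps = 4
per_step_batch_size = int(requested_batch_size / gradient_accumulation_steps)
assert per_step_batch_size * gradient_accumulation_steps == requested_batch_size  # 64 * 4 == 256
```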
def set_seed_logger(args):
    global logger
    # predefine random initial seeds
    random.seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # torch.multiprocessing.set_sharing_strategy('file_system')  # RuntimeError: received 0 items of ancdata.

    world_size = torch.distributed.get_world_size()
    local_rank = torch.distributed.get_rank()
    torch.cuda.set_device(local_rank)
    args.local_rank = local_rank

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    if local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(args.__dict__):
            logger.info("  <<< {}: {}".format(key, args.__dict__[key]))

    return args, world_size, local_rank

def init_device(args, local_rank):
    global logger

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)

    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    args.n_gpu = n_gpu

    if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
        raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
            args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))

    return device, n_gpu

def init_model(args, device, n_gpu, local_rank):

    if args.init_model:
        model_state_dict = torch.load(args.init_model, map_location='cpu')
    else:
        model_state_dict = None

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
                                   cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    model.to(device)

    return model

def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):

    if hasattr(model, 'module'):
        model = model.module

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Note: despite the names, `no_decay_param_tp` holds the parameters that do NOT match
    # the `no_decay` keywords (and therefore receive weight decay below), while
    # `decay_param_tp` holds the bias/LayerNorm parameters (which receive no decay).
    no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
    decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]

    no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n]
    no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n]

    decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n]
    decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n]

    # BERT-branch parameters get a scaled learning rate (args.lr * coef_lr).
    optimizer_grouped_parameters = [
        {'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01},
        {'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
    ]

    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
                         schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
                         max_grad_norm=1.0)

    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                          output_device=local_rank, find_unused_parameters=True)

    return optimizer, scheduler, model

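For reference, the same decay/no-decay grouping pattern can be sketched standalone with plain `torch.optim.AdamW` (`BertAdam` above comes from pytorch_pretrained_bert and additionally handles warmup; the `Tiny` module here is purely illustrative):

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)
    def forward(self, x):
        return self.LayerNorm(self.proj(x))

model = Tiny()
no_decay = ['bias', 'LayerNorm.weight']
groups = [
    # regular weights: decayed
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    # biases and LayerNorm parameters: not decayed
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(groups, lr=1e-4)
```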
def dataloader_youcook_train(args, tokenizer):
    logger.info('Loading captions: {}'.format(args.youcook_caption_path))
    caption = pickle.load(open(args.youcook_caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    youcook_dataset = Youcook_Transcript_NoPair_DataLoader(
        csv=args.youcook_train_csv,
        features_path=args.youcook_features_path_2D,
        features_path_3D=args.youcook_features_path_3D,
        caption=caption,
        min_time=args.min_time,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=-1,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(youcook_dataset)
    dataloader = DataLoader(
        youcook_dataset,
        batch_size=args.batch_size // args.n_gpu,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        drop_last=True,
    )

    return dataloader, len(youcook_dataset), train_sampler

def dataloader_youcook_test(args, tokenizer):
    logger.info('Loading captions: {}'.format(args.youcook_caption_path))
    caption = pickle.load(open(args.youcook_caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    youcook_testset = Youcook_Transcript_NoPair_DataLoader(
        csv=args.youcook_val_csv,
        features_path=args.youcook_features_path_2D,
        features_path_3D=args.youcook_features_path_3D,
        caption=caption,
        min_time=args.min_time,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=-1,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

    test_sampler = SequentialSampler(youcook_testset)
    dataloader_youcook = DataLoader(
        youcook_testset,
        sampler=test_sampler,
        batch_size=args.batch_size_val,
        num_workers=args.num_thread_reader,
        pin_memory=False,
    )

    logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
    return dataloader_youcook, len(youcook_testset)

def dataloader_pretrain(args, tokenizer, only_sim=False):
    logger.info('Loading captions: {}'.format(args.caption_path))
    caption = pickle.load(open(args.caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    dataset = Youtube_Transcript_DataLoader(
        csv=args.train_csv,
        features_path=args.features_path_2D,
        features_path_3D=args.features_path_3D,
        caption=caption,
        min_time=args.min_time,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=args.n_pair,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
        use_mil=args.use_mil,
        only_sim=only_sim,
        sampled_use_mil=args.sampled_use_mil,
        pretrain_enhance_vmodal=args.pretrain_enhance_vmodal,
        video_dim=args.video_dim,
    )

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size // args.n_gpu,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )

    return dataloader, len(dataset), sampler

def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    # Recursively cast every tensor in a (possibly nested) state dict to `ttype`.
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v, ttype)  # propagate ttype into nested values
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v, ttype) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict

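As a small usage example, casting an optimizer state dict to FP32 before checkpointing might look like this (toy model and optimizer are illustrative; `convert_state_dict_type` is the function defined just above):

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(2, 4)).sum().backward()
opt.step()  # populates momentum buffers in opt.state_dict()

# cast every tensor in the optimizer state to FP32 before writing a checkpoint
fp32_state = convert_state_dict_type(opt.state_dict(), torch.FloatTensor)
```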
def save_model(epoch, args, model, local_rank, type_name="", global_step=-1, optimizer=None):
    # Only save the model itself
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(
        args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name == "" else type_name + ".", epoch))
    torch.save(model_to_save.state_dict(), output_model_file)
    logger.info("Model saved to %s", output_model_file)

    if global_step != -1 and optimizer is not None:
        amp_state_dict = {}
        if args.fp16:
            amp_state_dict = amp.state_dict()
        state_dict = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': model_to_save.state_dict(),
            'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
            'amp_state_dict': amp_state_dict,
        }
        checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
        torch.save(state_dict, checkpoint_model_file)
        logger.info("Checkpoint is saved. Use `load_checkpoint` to recover it.")

    return output_model_file

def load_model(epoch, args, n_gpu, device, model, global_step=0, model_file=None):
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))

    last_optim_state = None
    amp_state_dict = None
    checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
    if epoch == -1 and args.load_checkpoint and os.path.exists(checkpoint_model_file):
        checkpoint_state = torch.load(checkpoint_model_file, map_location='cpu')
        epoch = checkpoint_state['epoch']
        global_step = checkpoint_state['global_step']
        model_state_dict = checkpoint_state['model_state_dict']
        last_optim_state = checkpoint_state['last_optimizer_state']
        if args.fp16 and 'amp_state_dict' in checkpoint_state:
            amp_state_dict = checkpoint_state['amp_state_dict']
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
                                       cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)
        logger.info("Checkpoint loaded from %s", checkpoint_model_file)
    elif os.path.exists(model_file):
        model_state_dict = torch.load(model_file, map_location='cpu')
        logger.info("Model loaded from %s", model_file)
        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
                                       cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)

    return epoch, global_step, last_optim_state, amp_state_dict, model

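Together, `save_model` and `load_model` implement a resumable-checkpoint round trip (model weights, optimizer state, amp state, epoch, and global step). A minimal sketch of the same pattern with generic PyTorch objects (names illustrative, not this repo's API):

```python
import torch

def save_ckpt(path, epoch, step, model, optimizer):
    torch.save({'epoch': epoch, 'global_step': step,
                'model_state_dict': model.state_dict(),
                'last_optimizer_state': optimizer.state_dict()}, path)

def load_ckpt(path, model, optimizer):
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['last_optimizer_state'])
    return state['epoch'] + 1, state['global_step']  # resume from the next epoch
```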
def train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer, scheduler,
                global_step, nlgEvalObj=None, local_rank=0):
    global logger
    torch.cuda.empty_cache()
    model.train()
    log_step = args.n_display
    start_time = time.time()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        # if n_gpu == 1:
        #     # multi-gpu does scattering itself
        #     batch = tuple(t.to(device) for t in batch)
        batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask, \
        pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
        pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch

        loss = model(input_ids, segment_ids, input_mask, video, video_mask,
                     pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels,
                     masked_video=masked_video, video_labels_index=video_labels_index,
                     input_caption_ids=pairs_input_caption_ids, decoder_mask=pairs_decoder_mask,
                     output_caption_ids=pairs_output_caption_ids)

        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step()  # Update learning rate schedule

            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            if global_step % log_step == 0 and local_rank == 0:
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
                            len(train_dataloader), "-".join([str('%.6f' % itm) for itm in sorted(list(set(optimizer.get_lr())))]),
                            float(loss) * args.gradient_accumulation_steps,
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                start_time = time.time()

    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step


# ---------------------------------------->
def get_inst_idx_to_tensor_position_map(inst_idx_list):
    ''' Indicate the position of an instance in a tensor. '''
    return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}

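A tiny illustration of the bookkeeping this provides during beam search: once some instances finish, the surviving instance ids are remapped to compacted tensor rows.

```python
# If only instances 3 and 7 are still active, their beams sit at rows 0 and 1.
assert get_inst_idx_to_tensor_position_map([3, 7]) == {3: 0, 7: 1}
```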
def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
    ''' Collect tensor parts associated to active instances. '''

    _, *d_hs = beamed_tensor.size()
    n_curr_active_inst = len(curr_active_inst_idx)
    new_shape = (n_curr_active_inst * n_bm, *d_hs)

    beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
    beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
    beamed_tensor = beamed_tensor.view(*new_shape)

    return beamed_tensor


def collate_active_info(input_tuples, inst_idx_to_position_map, active_inst_idx_list, n_bm, device):
    assert isinstance(input_tuples, tuple)
    sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt = input_tuples

    # Sentences which are still active are collected,
    # so the decoder will not run on completed sentences.
    n_prev_active_inst = len(inst_idx_to_position_map)
    active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
    active_inst_idx = torch.LongTensor(active_inst_idx).to(device)

    active_sequence_output_rpt = collect_active_part(sequence_output_rpt, active_inst_idx, n_prev_active_inst, n_bm)
    active_visual_output_rpt = collect_active_part(visual_output_rpt, active_inst_idx, n_prev_active_inst, n_bm)
    active_input_ids_rpt = collect_active_part(input_ids_rpt, active_inst_idx, n_prev_active_inst, n_bm)
    active_input_mask_rpt = collect_active_part(input_mask_rpt, active_inst_idx, n_prev_active_inst, n_bm)
    active_video_mask_rpt = collect_active_part(video_mask_rpt, active_inst_idx, n_prev_active_inst, n_bm)
    active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

    return (active_sequence_output_rpt, active_visual_output_rpt, active_input_ids_rpt, active_input_mask_rpt, active_video_mask_rpt), \
           active_inst_idx_to_position_map

def beam_decode_step(decoder, inst_dec_beams, len_dec_seq,
                     inst_idx_to_position_map, n_bm, device, input_tuples, decoder_length=None):
    ''' Decode and update beam status, and then return active beam idx. '''

    assert isinstance(input_tuples, tuple)

    def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
        dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
        dec_partial_seq = torch.stack(dec_partial_seq).to(device)
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
        return dec_partial_seq

    def predict_word(next_decoder_ids, n_active_inst, n_bm, device, input_tuples):
        sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt = input_tuples
        next_decoder_mask = torch.ones(next_decoder_ids.size(), dtype=torch.uint8).to(device)

        dec_output = decoder(sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt,
                             video_mask_rpt, next_decoder_ids, next_decoder_mask, shaped=True, get_logits=True)
        dec_output = dec_output[:, -1, :]
        word_prob = torch.nn.functional.log_softmax(dec_output, dim=1)
        word_prob = word_prob.view(n_active_inst, n_bm, -1)
        return word_prob

    def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map, decoder_length=None):
        active_inst_idx_list = []
        for inst_idx, inst_position in inst_idx_to_position_map.items():
            if decoder_length is None:
                is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
            else:
                is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position], word_length=decoder_length[inst_idx])
            if not is_inst_complete:
                active_inst_idx_list += [inst_idx]

        return active_inst_idx_list

    n_active_inst = len(inst_idx_to_position_map)
    dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
    word_prob = predict_word(dec_seq, n_active_inst, n_bm, device, input_tuples)

    # Update the beam with predicted word prob information and collect incomplete instances
    active_inst_idx_list = collect_active_inst_idx_list(inst_dec_beams, word_prob, inst_idx_to_position_map,
                                                        decoder_length=decoder_length)

    return active_inst_idx_list

def collect_hypothesis_and_scores(inst_dec_beams, n_best):
    all_hyp, all_scores = [], []
    for inst_idx in range(len(inst_dec_beams)):
        scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
        all_scores += [scores[:n_best]]

        hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
        all_hyp += [hyps]
    return all_hyp, all_scores
# >----------------------------------------

def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=None, nlgEvalObj=None, test_set=None):

    if hasattr(model, 'module'):
        model = model.module.to(device)

    if model._choice_sim:
        return 0.

    all_result_lists = []
    all_caption_lists = []
    model.eval()
    for batch in test_dataloader:
        batch = tuple(t.to(device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask, \
        pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
        pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch

        with torch.no_grad():
            sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask)
            # -- Repeat data for beam search
            n_bm = 5  # beam_size
            device = sequence_output.device
            n_inst, len_s, d_h = sequence_output.size()
            _, len_v, v_h = visual_output.size()

            decoder = model.decoder_caption  # This is a decoder function

            # Note: shaped first, then the decoder needs the parameter shaped=True
            input_ids = input_ids.view(-1, input_ids.shape[-1])  # [batch_size*n_pair, self.max_words]
            input_mask = input_mask.view(-1, input_mask.shape[-1])  # [batch_size*n_pair, self.max_words]
            video_mask = video_mask.view(-1, video_mask.shape[-1])  # [batch_size*n_pair, self.max_frames]

            sequence_output_rpt = sequence_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
            visual_output_rpt = visual_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_v, v_h)
            input_ids_rpt = input_ids.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            input_mask_rpt = input_mask.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            video_mask_rpt = video_mask.repeat(1, n_bm).view(n_inst * n_bm, len_v)

            # -- Prepare beams
            inst_dec_beams = [Beam(n_bm, device=device, tokenizer=tokenizer) for _ in range(n_inst)]
            # -- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
            # -- Decode
            for len_dec_seq in range(1, args.max_words + 1):
                active_inst_idx_list = beam_decode_step(decoder, inst_dec_beams,
                                                        len_dec_seq, inst_idx_to_position_map, n_bm, device,
                                                        (sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt))

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                (sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt), \
                inst_idx_to_position_map = collate_active_info((sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt),
                                                               inst_idx_to_position_map, active_inst_idx_list, n_bm, device)

            batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
            result_list = [batch_hyp[i][0] for i in range(n_inst)]

            pairs_output_caption_ids = pairs_output_caption_ids.view(-1, pairs_output_caption_ids.shape[-1])
            caption_list = pairs_output_caption_ids.cpu().detach().numpy()

            for re_idx, re_list in enumerate(result_list):
                decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
                if "[SEP]" in decode_text_list:
                    SEP_index = decode_text_list.index("[SEP]")
                    decode_text_list = decode_text_list[:SEP_index]
                if "[PAD]" in decode_text_list:
                    PAD_index = decode_text_list.index("[PAD]")
                    decode_text_list = decode_text_list[:PAD_index]
                decode_text = ' '.join(decode_text_list)
                decode_text = decode_text.replace(" ##", "").strip("##").strip()
                all_result_lists.append(decode_text)

            for re_idx, re_list in enumerate(caption_list):
                decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
                if "[SEP]" in decode_text_list:
                    SEP_index = decode_text_list.index("[SEP]")
                    decode_text_list = decode_text_list[:SEP_index]
                if "[PAD]" in decode_text_list:
                    PAD_index = decode_text_list.index("[PAD]")
                    decode_text_list = decode_text_list[:PAD_index]
                decode_text = ' '.join(decode_text_list)
                decode_text = decode_text.replace(" ##", "").strip("##").strip()
                all_caption_lists.append(decode_text)

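The `replace(" ##", "")` line above undoes WordPiece subword splitting; it also strips a stray leading `##`. A toy illustration (tokens are made up):

```python
tokens = ['add', 'the', 'tom', '##ato', '##es']
text = ' '.join(tokens).replace(' ##', '')
assert text == 'add the tomatoes'
```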
    # Save full results
    if test_set is not None and hasattr(test_set, 'iter2video_pairs_dict'):
        hyp_path = os.path.join(args.output_dir, "hyp_complete_results.txt")
        with open(hyp_path, "w", encoding='utf-8') as writer:
            writer.write("{}\t{}\t{}\n".format("video_id", "start_time", "caption"))
            for idx, pre_txt in enumerate(all_result_lists):
                video_id, sub_id = test_set.iter2video_pairs_dict[idx]
                start_time = test_set.caption[video_id]['start'][sub_id]
                writer.write("{}\t{}\t{}\n".format(video_id, start_time, pre_txt))
        logger.info("File of complete results is saved in {}".format(hyp_path))

    # Save pure results
    hyp_path = os.path.join(args.output_dir, "hyp.txt")
    with open(hyp_path, "w", encoding='utf-8') as writer:
        for pre_txt in all_result_lists:
            writer.write(pre_txt + "\n")

    ref_path = os.path.join(args.output_dir, "ref.txt")
    with open(ref_path, "w", encoding='utf-8') as writer:
        for ground_txt in all_caption_lists:
            writer.write(ground_txt + "\n")

    # Filter out hyps of 0 length
    hyps_and_refs = zip(all_result_lists, all_caption_lists)
    hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0
                     and len([" ".join(sub_s.split()) for sub_s in _[0].split(".") if len(sub_s) > 0]) > 0]
    all_result_lists, all_caption_lists = zip(*hyps_and_refs)

    # Evaluate
    metrics_dict = rougeObj.get_scores(hyps=all_result_lists, refs=all_caption_lists, avg=True, ignore_empty=True)
    logger.info(">>> rouge_1f: {:.4f}, rouge_2f: {:.4f}, rouge_lf: {:.4f}".
                format(metrics_dict["rouge-1"]["f"], metrics_dict["rouge-2"]["f"], metrics_dict["rouge-l"]["f"]))

    metrics_nlg = nlgEvalObj.compute_metrics(ref_list=[all_caption_lists], hyp_list=all_result_lists)
    logger.info(">>> BLEU_1: {:.4f}, BLEU_2: {:.4f}, BLEU_3: {:.4f}, BLEU_4: {:.4f}".
                format(metrics_nlg["Bleu_1"], metrics_nlg["Bleu_2"], metrics_nlg["Bleu_3"], metrics_nlg["Bleu_4"]))
    logger.info(">>> METEOR: {:.4f}, ROUGE_L: {:.4f}, CIDEr: {:.4f}".format(metrics_nlg["METEOR"], metrics_nlg["ROUGE_L"], metrics_nlg["CIDEr"]))

    Bleu_4 = metrics_nlg["Bleu_4"]
    return Bleu_4

DATALOADER_DICT = {}
DATALOADER_DICT["youcook"] = {"train": dataloader_youcook_train, "val": dataloader_youcook_test}

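This registry keys a dataset name to its train/val loader factories, so `main` can stay dataset-agnostic. Hypothetically, another dataset would only need a pair of functions with the same `(args, tokenizer)` signature (names below are illustrative, not part of this commit):

```python
def dataloader_msrvtt_train(args, tokenizer): ...  # hypothetical loader factory
def dataloader_msrvtt_test(args, tokenizer): ...   # hypothetical loader factory

DATALOADER_DICT["msrvtt"] = {"train": dataloader_msrvtt_train, "val": dataloader_msrvtt_test}
```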
def main():
    global logger
    args = get_args()
    args, world_size, local_rank = set_seed_logger(args)
    device, n_gpu = init_device(args, local_rank)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    if args.do_train:
        args.task_type = "caption"
    model = init_model(args, device, n_gpu, local_rank)
    only_sim = model.module._choice_sim if hasattr(model, 'module') else model._choice_sim

    if args.task_type == "caption":
        rougeObj = Rouge()
        nlgEvalObj = NLGEval(no_overlap=False, no_skipthoughts=True, no_glove=True, metrics_to_omit=None)

    datatype = "youcook"
    assert datatype in DATALOADER_DICT
    if args.do_pretrain is False:
        test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
        if local_rank == 0:
            logger.info("***** Running test *****")
            logger.info("  Num examples = %d", test_length)
            logger.info("  Batch size = %d", args.batch_size_val)
            logger.info("  Num steps = %d", len(test_dataloader))

    if args.do_pretrain:
        train_dataloader, train_length, sampler = dataloader_pretrain(args, tokenizer, only_sim=only_sim)
        num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        / args.gradient_accumulation_steps) * args.epochs

        global_step = 0
        epoch = -1
        last_optim_state = None
        amp_state_dict = None
        if args.load_checkpoint:
            epoch, global_step, last_optim_state, amp_state_dict, model = load_model(epoch, args, n_gpu, device, model, global_step=global_step)
            epoch += 1
            if local_rank == 0:
                logger.warning("Will continue to epoch: {}".format(epoch))
        epoch = 0 if epoch < 0 else epoch

        coef_lr = args.coef_lr
        if args.init_model:
            coef_lr = 1.0

        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)
        if last_optim_state is not None:
            optimizer.load_state_dict(last_optim_state)
        if amp_state_dict is not None:
            amp.load_state_dict(amp_state_dict)

        if local_rank == 0:
            logger.info("***** Running pretraining *****")
            logger.info("  Num examples = %d", train_length)
            logger.info("  Batch size = %d", args.batch_size)
            logger.info("  Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        iter_ls_ = [itm for itm in range(args.epochs) if itm >= epoch]
        for epoch in iter_ls_:
            sampler.set_epoch(epoch)

            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
                                               scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)

            if local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
                save_model(epoch, args, model, local_rank, type_name="pretrain", global_step=global_step, optimizer=optimizer)
    elif args.do_train:

        train_dataloader, train_length, train_sampler = DATALOADER_DICT[datatype]["train"](args, tokenizer)
        num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        / args.gradient_accumulation_steps) * args.epochs

        coef_lr = args.coef_lr
        if args.init_model:
            coef_lr = 1.0
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)

        if local_rank == 0:
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", train_length)
            logger.info("  Batch size = %d", args.batch_size)
            logger.info("  Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        best_score = 0.00001
        best_output_model_file = None
        global_step = 0
        for epoch in range(args.epochs):
            train_sampler.set_epoch(epoch)

            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
                                               scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)

            if local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
                output_model_file = save_model(epoch, args, model, local_rank, type_name="")
                if epoch > 0:
                    Bleu_4 = eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
                    if best_score <= Bleu_4:
                        best_score = Bleu_4
                        best_output_model_file = output_model_file
                    logger.info("The best model is: {}, the Bleu_4 is: {:.4f}".format(best_output_model_file, best_score))
                else:
                    logger.warning("Skip the evaluation after {}-th epoch.".format(epoch + 1))

        if local_rank == 0:
            _, _, _, _, model = load_model(-1, args, n_gpu, device, model, model_file=best_output_model_file)
            eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
    elif args.do_eval:
        if local_rank == 0:
            eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)


if __name__ == "__main__":
    main()
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
import threading
from torch._utils import ExceptionWrapper

def get_a_var(obj):
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, list) or isinstance(obj, tuple):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None

def parallel_apply(fct, model, inputs, device_ids):
    modules = nn.parallel.replicate(model, device_ids)
    assert len(modules) == len(inputs)
    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input):
        torch.set_grad_enabled(grad_enabled)
        device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = fct(module, *input)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker, args=(i, module, input))
                   for i, (module, input) in enumerate(zip(modules, inputs))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
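For orientation, `parallel_apply` replicates the model onto each device and runs one pre-scattered input chunk per replica in its own thread; `fct` decides what each replica actually computes. A hypothetical usage sketch (requires two GPUs; the call is left commented for that reason):

```python
def fct(replica, x):
    # each worker thread calls this with its replica and that replica's inputs
    return replica(x)

# inputs must be pre-scattered, one chunk per device, e.g.:
# outputs = parallel_apply(fct, model, [x_gpu0, x_gpu1], device_ids=[0, 1])
```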
@@ -0,0 +1,210 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import random

class Youcook_NoPair_DataLoader(Dataset):
    """Youcook dataset loader."""
    def __init__(
            self,
            csv,
            features_path,
            caption,
            tokenizer,
            min_time=10.0,
            features_path_3D=None,
            feature_framerate=1.0,
            feature_framerate_3D=24.0 / 16.0,
            max_words=30,
            min_words=0,
            n_pair=-1,  # -1 for test
            pad_token="[PAD]",
            max_frames=100,
    ):
        """
        Args:
        """
        self.csv = pd.read_csv(csv)
        self.features_path_2D = features_path
        self.features_path_3D = features_path_3D
        self.caption = caption
        self.min_time = min_time
        self.feature_framerate = feature_framerate
        self.feature_framerate_3D = feature_framerate_3D
        self.max_words = max_words
        self.max_frames = max_frames
        self.min_words = min_words
        self.tokenizer = tokenizer
        self.n_pair = n_pair
        self.pad_token = pad_token
        self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
        self.feature_path = {'2d': features_path}
        if features_path_3D != '':
            self.feature_path['3d'] = features_path_3D

        # Get iterator video ids
        video_id_list = [itm for itm in self.csv['video_id'].values]
        self.video_id2idx_dict = {video_id: id for id, video_id in enumerate(video_id_list)}
        # Get all captions
        self.iter2video_pairs_dict = {}
        iter_idx_ = 0
        for video_id in video_id_list:
            caption = self.caption[video_id]
            n_caption = len(caption['start'])
            for sub_id in range(n_caption):
                self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
                iter_idx_ += 1

    def __len__(self):
        return len(self.iter2video_pairs_dict)

    def _get_text(self, video_id, sub_id):
        caption = self.caption[video_id]
        k = 1
        r_ind = [sub_id]

        starts = np.zeros(k)
        ends = np.zeros(k)
        # np.int64 replaces the deprecated np.long alias (identical type)
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_masked_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_token_labels = np.zeros((k, self.max_words), dtype=np.int64)

        for i in range(k):
            ind = r_ind[i]
            # Note: n_pair_max=-1 means eval
            words = self.tokenizer.tokenize(caption['text'][ind])
            start_, end_ = caption['start'][ind], caption['end'][ind]
            starts[i], ends[i] = start_, end_

            words = ["[CLS]"] + words
            total_length_with_CLS = self.max_words - 1
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + ["[SEP]"]

            # Mask Language Model <-----
            token_labels = []
            masked_tokens = words.copy()
            for token_id, token in enumerate(masked_tokens):
                if token_id == 0 or token_id == len(masked_tokens) - 1:
                    token_labels.append(-1)
                    continue
                prob = random.random()
                # mask token with 15% probability
                if prob < 0.15:
                    prob /= 0.15

                    # 80% randomly change token to mask token
                    if prob < 0.8:
                        masked_tokens[token_id] = "[MASK]"

                    # 10% randomly change token to random token
                    elif prob < 0.9:
                        masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]

                    # -> rest 10% randomly keep current token

                    # append current token to output (we will predict these later)
                    try:
                        token_labels.append(self.tokenizer.vocab[token])
                    except KeyError:
                        # For unknown words (should not occur with BPE vocab)
                        token_labels.append(self.tokenizer.vocab["[UNK]"])
                        # print("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
                else:
                    # no masking token (will be ignored by loss function later)
                    token_labels.append(-1)
            # -----> Mask Language Model

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                masked_token_ids.append(0)
                token_labels.append(-1)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words
            assert len(masked_token_ids) == self.max_words
            assert len(token_labels) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)
            pairs_masked_text[i] = np.array(masked_token_ids)
            pairs_token_labels[i] = np.array(token_labels)

        return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, starts, ends

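The masking loop in `_get_text` is the standard BERT 80/10/10 scheme. A minimal standalone sketch of the same logic, with a toy vocab standing in for `tokenizer.vocab` (illustrative only; special tokens are never masked):

```python
import random

def mask_tokens(tokens, vocab, mask_rate=0.15):
    labels, masked = [], list(tokens)
    for i, tok in enumerate(tokens):
        if i == 0 or i == len(tokens) - 1:  # skip [CLS] / [SEP]
            labels.append(-1)
            continue
        p = random.random()
        if p < mask_rate:
            p /= mask_rate
            if p < 0.8:
                masked[i] = '[MASK]'                     # 80%: replace with [MASK]
            elif p < 0.9:
                masked[i] = random.choice(list(vocab))   # 10%: random token
            # remaining 10%: keep the original token
            labels.append(vocab.get(tok, vocab['[UNK]']))  # position is predicted
        else:
            labels.append(-1)  # unmasked positions are ignored by the loss
    return masked, labels

vocab = {'[UNK]': 0, 'add': 1, 'the': 2, 'salt': 3}
print(mask_tokens(['[CLS]', 'add', 'the', 'salt', '[SEP]'], vocab))
```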
    def _get_video(self, idx, s, e):
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.int64)
        max_video_length = [0] * len(s)
        f_title = "feature_file_2D"
        fps_k = "2d"

        feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
        video_features = np.load(feature_file)

        video = np.zeros((len(s), self.max_frames, video_features.shape[-1]), dtype=np.float64)
        for i in range(len(s)):
            start = int(s[i] * self.fps[fps_k])
            end = int(e[i] * self.fps[fps_k]) + 1
            video_slice = video_features[start:end]

            if self.max_frames < video_slice.shape[0]:
                video_slice = video_slice[:self.max_frames]

            slice_shape = video_slice.shape
            max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
            if len(video_slice) < 1:
                print("video_id: {}, start: {}, end: {}".format(feature_file, start, end))
            else:
                video[i][:slice_shape[0]] = video_slice

        for i, v_length in enumerate(max_video_length):
            video_mask[i][:v_length] = [1] * v_length

        # Mask Frame Model <-----
        video_labels_index = [[] for _ in range(len(s))]
        masked_video = video.copy()
        for i, video_pair_ in enumerate(masked_video):
            for j, _ in enumerate(video_pair_):
                if j < max_video_length[i]:
                    prob = random.random()
                    # mask frame with 15% probability
                    if prob < 0.15:
                        masked_video[i][j] = [0.] * video.shape[-1]
                        video_labels_index[i].append(j)
                    else:
                        video_labels_index[i].append(-1)
                else:
                    video_labels_index[i].append(-1)
        video_labels_index = np.array(video_labels_index, dtype=np.int64)
        # -----> Mask Frame Model

        return video, video_mask, masked_video, video_labels_index

    def __getitem__(self, feature_idx):

        video_id, sub_id = self.iter2video_pairs_dict[feature_idx]
        idx = self.video_id2idx_dict[video_id]

        pairs_text, pairs_mask, pairs_segment, \
        pairs_masked_text, pairs_token_labels, starts, ends = self._get_text(video_id, sub_id)

        video, video_mask, masked_video, video_labels_index = self._get_video(idx, starts, ends)

        return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
               pairs_masked_text, pairs_token_labels, masked_video, video_labels_index
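The "Mask Frame Model" block in `_get_video` is the visual counterpart of token masking: roughly 15% of the valid frames have their feature vectors zeroed, and the masked indices become reconstruction targets. A standalone sketch with illustrative shapes:

```python
import random
import numpy as np

video = np.random.rand(4, 1024)  # 4 valid frames of 1024-d features
masked = video.copy()
labels = []
for j in range(video.shape[0]):
    if random.random() < 0.15:
        masked[j] = 0.0    # masked frame: feature vector zeroed out
        labels.append(j)   # the model must reconstruct this frame
    else:
        labels.append(-1)  # ignored by the reconstruction loss
```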
@@ -0,0 +1,246 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import re
import random
import io

class Youcook_Transcript_NoPair_DataLoader(Dataset):
    """Youcook dataset loader."""

    def __init__(
            self,
            csv,
            features_path,
            caption,
            tokenizer,
            min_time=10.0,
            features_path_3D=None,
            feature_framerate=1.0,
            feature_framerate_3D=24.0 / 16.0,
            max_words=30,
            min_words=0,
            n_pair=-1,  # -1 for test
            pad_token="[PAD]",
            max_frames=100,
    ):
        """
        Args:
        """
        self.csv = pd.read_csv(csv)
        self.features_path_2D = features_path
        self.features_path_3D = features_path_3D
        self.caption = caption
        self.min_time = min_time
        self.feature_framerate = feature_framerate
        self.feature_framerate_3D = feature_framerate_3D
        self.max_words = max_words
        self.max_frames = max_frames
        self.min_words = min_words
        self.tokenizer = tokenizer
        self.n_pair = n_pair
        self.pad_token = pad_token
        self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
        self.feature_path = {'2d': features_path}
        if features_path_3D != '':
            self.feature_path['3d'] = features_path_3D

        _feature_file = os.path.join(features_path, self.csv["feature_file_2D"].values[0])
        self.feature_size = np.load(_feature_file).shape[-1]

        # Get iterator video ids
        video_id_list = [itm for itm in self.csv['video_id'].values]
        self.video_id2idx_dict = {video_id: id for id, video_id in enumerate(video_id_list)}
        # Get all captions
        self.iter2video_pairs_dict = {}
        iter_idx_ = 0
        for video_id in video_id_list:
            caption = self.caption[video_id]
            n_caption = len(caption['start'])
            for sub_id in range(n_caption):
                self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
                iter_idx_ += 1

    def __len__(self):
        return len(self.iter2video_pairs_dict)

    def _get_text(self, video_id, sub_id):
        caption = self.caption[video_id]
        k = 1
        r_ind = [sub_id]

        starts = np.zeros(k)
        ends = np.zeros(k)
        # np.int64 replaces the deprecated np.long alias (identical type)
        pairs_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_masked_text = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_token_labels = np.zeros((k, self.max_words), dtype=np.int64)

        pairs_input_caption_ids = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_output_caption_ids = np.zeros((k, self.max_words), dtype=np.int64)
        pairs_decoder_mask = np.zeros((k, self.max_words), dtype=np.int64)

        for i in range(k):
            ind = r_ind[i]
            start_, end_ = caption['start'][ind], caption['end'][ind]
            starts[i], ends[i] = start_, end_
            total_length_with_CLS = self.max_words - 1
            words = self.tokenizer.tokenize(caption['transcript'][ind])

            words = ["[CLS]"] + words
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + ["[SEP]"]

            # Mask Language Model <-----
            token_labels = []
            masked_tokens = words.copy()
            for token_id, token in enumerate(masked_tokens):
                if token_id == 0 or token_id == len(masked_tokens) - 1:
                    token_labels.append(-1)
                    continue
                prob = random.random()
                # mask token with 15% probability
                if prob < 0.15:
                    prob /= 0.15

                    # 80% randomly change token to mask token
                    if prob < 0.8:
                        masked_tokens[token_id] = "[MASK]"

                    # 10% randomly change token to random token
                    elif prob < 0.9:
                        masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]

                    # -> rest 10% randomly keep current token

                    # append current token to output (we will predict these later)
                    try:
                        token_labels.append(self.tokenizer.vocab[token])
                    except KeyError:
                        # For unknown words (should not occur with BPE vocab)
                        token_labels.append(self.tokenizer.vocab["[UNK]"])
                        # print("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
                else:
                    # no masking token (will be ignored by loss function later)
                    token_labels.append(-1)
            # -----> Mask Language Model

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                masked_token_ids.append(0)
                token_labels.append(-1)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words
            assert len(masked_token_ids) == self.max_words
            assert len(token_labels) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)
            pairs_masked_text[i] = np.array(masked_token_ids)
            pairs_token_labels[i] = np.array(token_labels)

            # For generating captions
            caption_words = self.tokenizer.tokenize(caption['text'][ind])
            if len(caption_words) > total_length_with_CLS:
                caption_words = caption_words[:total_length_with_CLS]
            input_caption_words = ["[CLS]"] + caption_words
            output_caption_words = caption_words + ["[SEP]"]

            # For generating captions
            input_caption_ids = self.tokenizer.convert_tokens_to_ids(input_caption_words)
            output_caption_ids = self.tokenizer.convert_tokens_to_ids(output_caption_words)
            decoder_mask = [1] * len(input_caption_ids)
            while len(input_caption_ids) < self.max_words:
                input_caption_ids.append(0)
                output_caption_ids.append(0)
                decoder_mask.append(0)
            assert len(input_caption_ids) == self.max_words
            assert len(output_caption_ids) == self.max_words
            assert len(decoder_mask) == self.max_words

            pairs_input_caption_ids[i] = np.array(input_caption_ids)
            pairs_output_caption_ids[i] = np.array(output_caption_ids)
            pairs_decoder_mask[i] = np.array(decoder_mask)

        return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, \
               pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, starts, ends

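The caption-id construction above is the usual teacher-forcing shift: the decoder input is the caption prefixed with `[CLS]`, and the target is the same caption followed by `[SEP]`, so each position predicts the next token. A toy illustration:

```python
caption_words = ['add', 'the', 'salt']
input_caption_words = ['[CLS]'] + caption_words    # decoder input
output_caption_words = caption_words + ['[SEP]']   # decoder target (shifted by one)
for x, y in zip(input_caption_words, output_caption_words):
    print('{:>6} -> {}'.format(x, y))
```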
def _get_video(self, idx, s, e):
|
||||
video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
|
||||
max_video_length = [0] * len(s)
|
||||
f_title = "feature_file_2D"
|
||||
fps_k = "2d"
|
||||
|
||||
feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
|
||||
video_features = np.load(feature_file)
|
||||
|
||||
video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
|
||||
for i in range(len(s)):
|
||||
start = int(s[i] * self.fps[fps_k])
|
||||
end = int(e[i] * self.fps[fps_k]) + 1
|
||||
video_slice = video_features[start:end]
|
||||
|
||||
if self.max_frames < video_slice.shape[0]:
|
||||
video_slice = video_slice[:self.max_frames]
|
||||
|
||||
slice_shape = video_slice.shape
|
||||
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
|
||||
if len(video_slice) < 1:
|
||||
print("video_id: {}, start: {}, end: {}".format(feature_file, start, end))
|
||||
# pass
|
||||
else:
|
||||
video[i][:slice_shape[0]] = video_slice
|
||||
|
||||
for i, v_length in enumerate(max_video_length):
|
||||
video_mask[i][:v_length] = [1] * v_length
|
||||
|
||||
# Mask Frame Model <-----
|
||||
video_labels_index = [[] for _ in range(len(s))]
|
||||
masked_video = video.copy()
|
||||
for i, video_pair_ in enumerate(masked_video):
|
||||
for j, _ in enumerate(video_pair_):
|
||||
if j < max_video_length[i]:
|
||||
prob = random.random()
|
||||
# mask token with 15% probability
|
||||
if prob < 0.15:
|
||||
masked_video[i][j] = [0.] * video.shape[-1]
|
||||
video_labels_index[i].append(j)
|
||||
else:
|
||||
video_labels_index[i].append(-1)
|
||||
else:
|
||||
video_labels_index[i].append(-1)
|
||||
video_labels_index = np.array(video_labels_index, dtype=np.long)
|
||||
# -----> Mask Frame Model
|
||||
|
||||
return video, video_mask, masked_video, video_labels_index
|
||||
|
||||
def __getitem__(self, feature_idx):
|
||||
|
||||
video_id, sub_id = self.iter2video_pairs_dict[feature_idx]
|
||||
idx = self.video_id2idx_dict[video_id]
|
||||
|
||||
pairs_text, pairs_mask, pairs_segment, \
|
||||
pairs_masked_text, pairs_token_labels, pairs_input_caption_ids, \
|
||||
pairs_decoder_mask, pairs_output_caption_ids, starts, ends = self._get_text(video_id, sub_id)
|
||||
|
||||
video, video_mask, masked_video, video_labels_index = self._get_video(idx, starts, ends)
|
||||
|
||||
return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
|
||||
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
|
||||
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids
|
|
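For orientation, this is roughly what one sample from the loader above looks like. A minimal sketch, not part of the commit, assuming hypothetical sizes `max_words=48`, `max_frames=48`, `feature_size=1024`:
```
# Illustrative only (values hypothetical): shapes returned by __getitem__ above.
import numpy as np

k, max_words, max_frames, feature_size = 1, 48, 48, 1024
pairs_text = np.zeros((k, max_words), dtype=np.int64)     # [CLS] tokens [SEP] + padding ids
pairs_mask = np.zeros((k, max_words), dtype=np.int64)     # 1 on real tokens, 0 on padding
video = np.zeros((k, max_frames, feature_size))           # clipped/padded 2D feature slice
video_mask = np.zeros((k, max_frames), dtype=np.int64)    # 1 on real frames, 0 on padding
pairs_token_labels = np.full((k, max_words), -1)          # vocab id where masked, else -1
print(pairs_text.shape, video.shape)                      # (1, 48) (1, 48, 1024)
```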
@@ -0,0 +1,419 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import re
import random
import io

class Youtube_Transcript_DataLoader(Dataset):
    """
    YouTube dataset loader.
    Note: uses the transcript as the caption for the masked decoder pretraining task.
    """

    def __init__(
            self,
            csv,
            features_path,
            caption,
            tokenizer,
            min_time=10.0,
            features_path_3D=None,
            feature_framerate=1.0,
            feature_framerate_3D=24.0 / 16.0,
            max_words=30,
            min_words=0,
            n_pair=-1,  # -1 for test
            pad_token="[PAD]",
            max_frames=100,
            with_long_context=True,
            use_mil=False,
            only_sim=False,  # set automatically from model choice
            sampled_use_mil=False,
            pretrain_enhance_vmodal=False,
            video_dim=1024,
    ):
        """
        Args:
        """
        self.csv = pd.read_csv(csv)
        self.features_path_2D = features_path
        self.features_path_3D = features_path_3D
        self.caption = caption
        self.min_time = min_time
        self.feature_framerate = feature_framerate
        self.feature_framerate_3D = feature_framerate_3D
        self.max_words = max_words
        self.max_frames = max_frames
        self.min_words = min_words
        self.tokenizer = tokenizer
        self.n_pair = n_pair
        self.pad_token = pad_token
        self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
        self.feature_path = {'2d': features_path}
        if features_path_3D != '':
            self.feature_path['3d'] = features_path_3D
        self.with_long_context = with_long_context

        self.feature_size = video_dim

        self.only_sim = only_sim
        self.pretrain_enhance_vmodal = pretrain_enhance_vmodal
        self.iter_num = len(self.csv)

        self.use_mil = use_mil
        self.sampled_use_mil = sampled_use_mil
        if self.sampled_use_mil:  # sample from each video; takes priority over use_mil
            self.use_mil = True

        if self.use_mil:
            positive_n_pair = self.n_pair
            # Get iterator video ids
            video_id_list = [itm for itm in self.csv['video_id'].values]
            self.video_id2idx_dict = {video_id: id for id, video_id in enumerate(video_id_list)}
            # Get all captions
            self.iter2video_pairs_dict = {}
            self.iter2video_pairslist_dict = {}
            iter_idx_mil_ = 0
            for video_id in video_id_list:
                caption = self.caption[video_id]
                n_caption = len(caption['start'])

                sub_list = []
                if self.n_pair < 0 or self.n_pair == 1:
                    for sub_id in range(n_caption):
                        sub_list.append([sub_id])
                else:
                    sb_ls_ = list(range(n_caption))
                    sb_st_ = set(sb_ls_)
                    if self.n_pair > n_caption:
                        sb_ls_ = sb_ls_ * (self.n_pair // n_caption + 1)
                        sb_ls_ = sb_ls_[:self.n_pair]
                        for sub_id in np.arange(0, len(sb_ls_), self.n_pair):
                            sub_list.append(sb_ls_[sub_id: sub_id + self.n_pair])
                    else:
                        sb_ls_ = sb_ls_ + sb_ls_[:(((n_caption + positive_n_pair - 1) // positive_n_pair) * positive_n_pair - n_caption)]
                        for sub_id in np.arange(0, len(sb_ls_), positive_n_pair):
                            pos_ls = sb_ls_[sub_id: sub_id + positive_n_pair]
                            sub_list.append(pos_ls)

                for sub_e in sub_list:
                    self.iter2video_pairs_dict[iter_idx_mil_] = (video_id, sub_e)
                    iter_idx_mil_ += 1
                self.iter2video_pairslist_dict[video_id] = sub_list

        if self.use_mil and self.sampled_use_mil is False:
            self.iter_num = len(self.iter2video_pairs_dict)

    def __len__(self):
        return self.iter_num
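When `1 < n_pair <= n_caption`, the constructor pads the caption indices to the next multiple of `n_pair` by wrapping around and then chunks them. A minimal sketch of that arithmetic with hypothetical counts:
```
# n_caption=7 transcript clips grouped into chunks of n_pair=3 (hypothetical values)
n_caption, n_pair = 7, 3
sb_ls_ = list(range(n_caption))
# pad by wrapping so the length becomes the next multiple of n_pair: 7 -> 9
sb_ls_ = sb_ls_ + sb_ls_[:(((n_caption + n_pair - 1) // n_pair) * n_pair - n_caption)]
sub_list = [sb_ls_[i:i + n_pair] for i in range(0, len(sb_ls_), n_pair)]
print(sub_list)  # [[0, 1, 2], [3, 4, 5], [6, 0, 1]]
```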

    def _mask_tokens(self, words, orig2token_tuple_list=None, chunk_positions=None):
        token_labels = []
        masked_tokens = words.copy()

        if chunk_positions is not None and len(chunk_positions) > 0:
            token_labels = [-1] * len(masked_tokens)
            for chunk_ind in chunk_positions:
                start_, len_ = orig2token_tuple_list[chunk_ind]
                if random.random() < 0.15:
                    token_labels[start_:start_ + len_] = [self.tokenizer.vocab[tk_] for tk_ in masked_tokens[start_:start_ + len_]]
                    masked_tokens[start_:start_ + len_] = ["[MASK]"] * len_
        else:
            for token_id, token in enumerate(masked_tokens):
                if token_id == 0 or token_id == len(masked_tokens) - 1:
                    token_labels.append(-1)
                    continue
                prob = random.random()
                if prob < 0.15:
                    prob /= 0.15
                    if prob < 0.8:
                        masked_tokens[token_id] = "[MASK]"
                    elif prob < 0.9:
                        masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
                    try:
                        token_labels.append(self.tokenizer.vocab[token])
                    except KeyError:
                        token_labels.append(self.tokenizer.vocab["[UNK]"])
                else:
                    token_labels.append(-1)

        return masked_tokens, token_labels
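The else-branch above is the standard BERT masking policy: 15% of positions are selected for prediction; of those, 80% become `[MASK]`, 10% become a random vocabulary token, and 10% stay unchanged but are still predicted. A toy run with a mock vocabulary (the real code uses the bert-base-uncased WordPiece vocab):
```
import random

vocab = {"[CLS]": 0, "[SEP]": 1, "[MASK]": 2, "[UNK]": 3, "cut": 4, "the": 5, "onion": 6}
tokens = ["[CLS]", "cut", "the", "onion", "[SEP]"]

masked, labels = list(tokens), []
for i, tok in enumerate(tokens):
    if i == 0 or i == len(tokens) - 1:  # never mask [CLS]/[SEP]
        labels.append(-1)
        continue
    p = random.random()
    if p < 0.15:                        # position selected for prediction
        p /= 0.15
        if p < 0.8:
            masked[i] = "[MASK]"        # 80%: replace with [MASK]
        elif p < 0.9:
            masked[i] = random.choice(list(vocab))  # 10%: random token
        labels.append(vocab.get(tok, vocab["[UNK]"]))  # label is the original id
    else:
        labels.append(-1)               # -1 is ignored by the loss
print(masked, labels)
```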

    def _get_text(self, video_id, n_pair_max, sub_ids=None, only_sim=False, enhance_vmodel=False):
        caption = self.caption[video_id]

        if self.use_mil:
            k = len(sub_ids)
            r_ind = sub_ids
        else:
            n_caption = len(caption['start'])
            if n_pair_max == -1:
                k = n_caption
                r_ind = range(n_caption)
            else:
                k = n_pair_max
                if k <= n_caption:
                    r_ind = np.random.choice(range(n_caption), k, replace=False)
                else:
                    r_ind_must = np.array(range(n_caption))
                    r_ind_rand = np.random.choice(range(n_caption), k - n_caption, replace=True)
                    r_ind = np.concatenate((r_ind_must, r_ind_rand), axis=0)
                np.random.shuffle(r_ind)

        starts = np.zeros(k)
        ends = np.zeros(k)
        pairs_text = np.zeros((k, self.max_words), dtype=np.long)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
        pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
        pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)

        pairs_input_caption_ids = np.zeros((k, self.max_words), dtype=np.long)
        pairs_output_caption_ids = np.zeros((k, self.max_words), dtype=np.long)
        pairs_decoder_mask = np.zeros((k, self.max_words), dtype=np.long)

        for i in range(k):
            ind = r_ind[i]
            words, start_, end_ = self._get_single_transcript(caption, ind, with_long_context=self.with_long_context)
            caption_words = words.copy()
            starts[i], ends[i] = start_, end_

            orig2token_tuple_list, chunk_positions = None, None
            # # For entity mask
            # orig_sentence, orig_to_token_map = generate_org_sentence_from_tokens(words.copy())
            # four_tuples_ = generate_knoledge_candidate(orig_sentence, orig_to_token_map)
            # _, orig2token_tuple_list, chunk_positions, _ = four_tuples_

            if enhance_vmodel:
                words = []  # mask all input text

            words = ["[CLS]"] + words
            total_length_with_CLS = self.max_words - 1
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + ["[SEP]"]

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)

            if only_sim is False:
                # # For entity mask: +1 due to [CLS]
                # orig2token_tuple_list = [(s_ + 1, len_) for s_, len_ in orig2token_tuple_list if s_ + len_ <= total_length_with_CLS]
                # chunk_positions = [ind for ind in chunk_positions if ind < len(orig2token_tuple_list)]

                # For caption generation
                if len(caption_words) > total_length_with_CLS:
                    caption_words = caption_words[:total_length_with_CLS]
                input_caption_words = ["[CLS]"] + caption_words
                output_caption_words = caption_words + ["[SEP]"]

                masked_tokens, token_labels = self._mask_tokens(words, orig2token_tuple_list, chunk_positions)
                masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
                masked_input_caption_words, input_token_labels = self._mask_tokens(input_caption_words)
                input_caption_words = masked_input_caption_words.copy()

                while len(masked_token_ids) < self.max_words:
                    masked_token_ids.append(0)
                    token_labels.append(-1)
                assert len(masked_token_ids) == self.max_words
                assert len(token_labels) == self.max_words

                # For caption generation
                input_caption_ids = self.tokenizer.convert_tokens_to_ids(input_caption_words)
                output_caption_ids = self.tokenizer.convert_tokens_to_ids(output_caption_words)
                decoder_mask = [1] * len(input_caption_ids)
                while len(input_caption_ids) < self.max_words:
                    input_caption_ids.append(0)
                    output_caption_ids.append(0)
                    decoder_mask.append(0)
                assert len(input_caption_ids) == self.max_words
                assert len(output_caption_ids) == self.max_words
                assert len(decoder_mask) == self.max_words

                pairs_masked_text[i] = np.array(masked_token_ids)
                pairs_token_labels[i] = np.array(token_labels)

                pairs_input_caption_ids[i] = np.array(input_caption_ids)
                pairs_output_caption_ids[i] = np.array(output_caption_ids)
                pairs_decoder_mask[i] = np.array(decoder_mask)

        return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, \
               pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, starts, ends
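The decoder pair built above is the usual shifted teacher-forcing arrangement: the decoder reads `[CLS] + caption` and is trained to emit `caption + [SEP]`, so position t predicts token t+1. For a hypothetical two-word caption:
```
caption_words = ["add", "salt"]
input_caption_words = ["[CLS]"] + caption_words   # what the decoder sees
output_caption_words = caption_words + ["[SEP]"]  # what it must predict, one step ahead
for t, (inp, out) in enumerate(zip(input_caption_words, output_caption_words)):
    print(t, inp, "->", out)
# 0 [CLS] -> add
# 1 add -> salt
# 2 salt -> [SEP]
```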

    def _get_single_transcript(self, caption, ind, with_long_context=True):
        start, end = ind, ind
        words = self.tokenizer.tokenize(str(caption['text'][ind]))
        diff = caption['end'][end] - caption['start'][start]
        while with_long_context and (len(words) < self.min_words or diff < self.min_time):
            if start > 0 and end < len(caption['end']) - 1:
                next_words = self.tokenizer.tokenize(str(caption['text'][end + 1]))
                prev_words = self.tokenizer.tokenize(str(caption['text'][start - 1]))
                d1 = caption['end'][end + 1] - caption['start'][start]
                d2 = caption['end'][end] - caption['start'][start - 1]
                if (self.min_time > 0 and d2 <= d1) or \
                        (self.min_time == 0 and len(next_words) <= len(prev_words)):
                    start -= 1
                    words = prev_words + words
                else:
                    end += 1
                    words.extend(next_words)
            elif start > 0:
                words = self.tokenizer.tokenize(str(caption['text'][start - 1])) + words
                start -= 1
            elif end < len(caption['end']) - 1:
                words.extend(self.tokenizer.tokenize(str(caption['text'][end + 1])))
                end += 1
            else:
                break
            diff = caption['end'][end] - caption['start'][start]
        return words, caption['start'][start], caption['end'][end]
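In other words, a transcript clip that is too short is merged with its neighbours (the real method compares both directions and takes the one that keeps the merged span tighter) until it covers at least `min_time` seconds and `min_words` tokens. A deliberately simplified trace on a hypothetical three-clip caption dict, expanding left before right:
```
caption = {
    'text': ["chop", "the carrots", "into cubes"],
    'start': [0.0, 2.0, 5.0],
    'end': [2.0, 5.0, 9.0],
}
ind, min_time = 1, 10.0
start = end = ind
while caption['end'][end] - caption['start'][start] < min_time:
    if start > 0:
        start -= 1
    elif end < len(caption['end']) - 1:
        end += 1
    else:
        break  # whole video merged and still shorter than min_time
print(caption['start'][start], caption['end'][end])  # 0.0 9.0
```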

    def _expand_video_slice(self, s, e, si, ei, fps, video_features, fps_k):
        start = int(s[si] * fps)
        end = int(e[ei] * fps) + 1

        if start > end:
            start, end = end, start
        video_slice = video_features[start:end]

        expand_left = True
        while len(video_slice) < 1:
            if si == 0 and ei == len(s) - 1:
                break
            if expand_left:
                expand_left = False
                si = si - 1 if si > 0 else si
            else:
                expand_left = True
                ei = ei + 1 if ei < len(e) - 1 else ei
            start = int(s[si] * fps)
            end = int(e[ei] * fps) + 1
            if start > end:
                start, end = end, start
            video_slice = video_features[start:end]

        # # Note: align the features between 2d and 3d, since the fps differs.
        # if fps_k == "2d":
        #     indices = sorted([idx for idx in range(len(video_slice)) if idx % 2 == 1]
        #                      + [idx for idx in range(len(video_slice))])
        #     video_slice = np.take(video_slice, indices, axis=0)

        if self.max_frames < video_slice.shape[0]:
            video_slice = video_slice[:self.max_frames]

        return video_slice, start, end
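The expansion matters when an annotated segment maps to an empty feature slice, e.g. a timestamp past the end of the extracted features. A small sketch with hypothetical numbers:
```
import numpy as np

video_features = np.zeros((10, 4))          # 10 feature frames, 4-dim each
s, e, fps = [2.0, 12.0], [4.0, 13.0], 1.0   # second segment lies past the video end
si = ei = 1
start, end = int(s[si] * fps), int(e[ei] * fps) + 1
print(video_features[start:end].shape)      # (0, 4): empty, so the loop widens
si -= 1                                     # first expansion step goes left
start = int(s[si] * fps)
print(video_features[start:end].shape)      # (8, 4): frames 2..9
```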

    def _get_video(self, idx, s, e, only_sim=False):
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)

        max_video_length = [0] * len(s)
        f_title = "feature_file_2D"
        fps_k = "2d"

        video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
        feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
        try:
            video_features = np.load(feature_file)

            for i in range(len(s)):
                # start = int(s[i] * self.fps[fps_k])
                # end = int(e[i] * self.fps[fps_k]) + 1
                # video_slice = video_features[start:end]
                #
                # if self.max_frames < video_slice.shape[0]:
                #     video_slice = video_slice[:self.max_frames]
                if len(video_features) < 1:
                    raise ValueError("{} is empty.".format(feature_file))
                video_slice, start, end = self._expand_video_slice(s, e, i, i,
                                                                   self.fps[fps_k], video_features, fps_k)
                slice_shape = video_slice.shape
                max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
                if len(video_slice) < 1:
                    pass
                    # video_slice = video_features
                    # if self.max_frames < video_slice.shape[0]:
                    #     video_slice = video_slice[:self.max_frames]
                else:
                    video[i][:slice_shape[0]] = video_slice
        except Exception as excep:  # renamed from `e`, which shadowed the end-times argument
            print("video_id: {} error: {}".format(feature_file, excep))

        for i, v_length in enumerate(max_video_length):
            video_mask[i][:v_length] = [1] * v_length

        # Mask Frame Model <-----
        video_labels_index = [[] for _ in range(len(s))]
        masked_video = video.copy()
        if only_sim is False:
            for i, video_pair_ in enumerate(masked_video):
                for j, _ in enumerate(video_pair_):
                    if j < max_video_length[i]:
                        prob = random.random()
                        # mask each real frame with 15% probability
                        if prob < 0.15:
                            masked_video[i][j] = [0.] * video.shape[-1]
                            video_labels_index[i].append(j)
                        else:
                            video_labels_index[i].append(-1)
                    else:
                        video_labels_index[i].append(-1)
        video_labels_index = np.array(video_labels_index, dtype=np.long)
        # -----> Mask Frame Model

        return video, video_mask, masked_video, video_labels_index
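A toy version of the masked frame model step, detached from the class (sizes hypothetical): roughly 15% of the real frames are zeroed out and their indices recorded so the model can be asked to reconstruct them, while padded frames are marked -1:
```
import random
import numpy as np

max_frames, dim, real_len = 6, 4, 5
video = np.random.rand(1, max_frames, dim)
masked_video = video.copy()
labels = []
for j in range(max_frames):
    if j < real_len and random.random() < 0.15:
        masked_video[0][j] = 0.0  # hide this frame's features
        labels.append(j)          # position to reconstruct
    else:
        labels.append(-1)         # padding or unmasked: ignored by the loss
print(labels)
```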

    def second_to_stamp(self, in_seconds):
        m, s = divmod(in_seconds, 60)
        h, m2 = divmod(m, 60)
        return "%02d:%02d:%02d" % (h, m2, s)

    def __getitem__(self, feature_idx):
        if self.sampled_use_mil:  # sample from each video; takes priority over use_mil
            idx = feature_idx
            video_id = self.csv['video_id'].values[idx]
            sub_list = self.iter2video_pairslist_dict[video_id]
            ranint = np.random.randint(0, len(sub_list))
            sub_ids = sub_list[ranint]
        elif self.use_mil:
            video_id, sub_ids = self.iter2video_pairs_dict[feature_idx]
            idx = self.video_id2idx_dict[video_id]
        else:
            idx = feature_idx
            video_id = self.csv['video_id'].values[idx]
            sub_ids = None

        enhance_vmodel = False
        if self.only_sim is False and self.pretrain_enhance_vmodal:
            prob = random.random()
            if prob < 0.15:  # mask all input text with probability 0.15
                enhance_vmodel = True

        pairs_text, pairs_mask, pairs_segment, \
        pairs_masked_text, pairs_token_labels, pairs_input_caption_ids, \
        pairs_decoder_mask, pairs_output_caption_ids, \
        starts, ends = self._get_text(video_id, self.n_pair, sub_ids, only_sim=self.only_sim, enhance_vmodel=enhance_vmodel)

        video, video_mask, masked_video, video_labels_index = self._get_video(idx, starts, ends, only_sim=self.only_sim)

        return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
               pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
               pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids
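A hypothetical end-to-end sketch of wiring this loader into training; the csv/pickle/feature paths below are placeholders, and the tokenizer import assumes the pytorch_pretrained_bert package prepared in this repository:
```
import pickle
from torch.utils.data import DataLoader
from pytorch_pretrained_bert.tokenization import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
with open("data/howto100m_caption.pickle", "rb") as f:   # placeholder path
    caption = pickle.load(f)

dataset = Youtube_Transcript_DataLoader(
    csv="data/howto100m_train.csv",                      # placeholder path
    features_path="data/howto100m_videos_features",      # placeholder path
    caption=caption,
    tokenizer=tokenizer,
    max_words=48,
    max_frames=48,
    n_pair=3,
    use_mil=True,
)
loader = DataLoader(dataset, batch_size=32, num_workers=8, shuffle=True)
pairs_text, pairs_mask, pairs_segment, video, video_mask = next(iter(loader))[:5]
```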