ArrowLuo 2021-01-06 17:42:34 +08:00
Parent ba6b9ea0f1
Commit 4308eaff6c
29 changed files with 2306 additions and 2642 deletions

1
.gitignore vendored
View file

@ -1 +1,2 @@
.idea
READMEINS.md

172
README.md
View file

@ -1,9 +1,10 @@
The implementation of paper [**UniVL: A Unified Video and Language Pre-Training Model for Multimodal Understanding and Generation**](https://arxiv.org/abs/2002.06353).
# Preliminary
Execute the scripts below in the main folder first.
```
cd pytorch_pretrained_bert/bert-base-uncased/
mkdir modules/bert-base-uncased
cd modules/bert-base-uncased/
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt
mv bert-base-uncased-vocab.txt vocab.txt
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz
@ -12,81 +13,172 @@ rm bert-base-uncased.tar.gz
cd ../../
```
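A quick way to confirm the weights landed where the scripts above intend (this check is not part of the repo; it assumes the archive unpacks into `bert_config.json` and `pytorch_model.bin`, the usual layout of `bert-base-uncased.tar.gz`):
```python
# Hypothetical sanity check: confirm vocab and weights sit under modules/bert-base-uncased/.
import os

base = os.path.join("modules", "bert-base-uncased")
for name in ("vocab.txt", "bert_config.json", "pytorch_model.bin"):
    path = os.path.join(base, name)
    print(path, "OK" if os.path.isfile(path) else "MISSING")
```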
# Finetune on YoucookII
## Retrieval
# Requirements
- python==3.6.9
- torch==1.7.0+cu92
- tqdm
- boto3
- requests
- pandas
- nlg-eval (install Java 1.8.0 or higher first)
```
INIT_MODEL=<from second phase>
CUDA_VISIBLE_DEVICES=0,1,2,3 python train_retrieval_task.py \
--do_train --num_thread_reader=16 \
--epochs=5 --batch_size=32 \
--n_pair=1 --n_display=100 \
--youcook_train_csv data/youcookii_train.csv \
--youcook_val_csv data/youcookii_val.csv \
--youcook_caption_path data/youcookii_caption.pickle \
--youcook_features_path_2D data/youcookii_videos_features \
--output_dir ckpt_youcook_retrieval --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 48 --max_frames 48 \
--batch_size_val 200 --visual_num_hidden_layers 6 \
--datatype youcook --init_model ${INIT_MODEL}
conda create -n py_univl python=3.6.9 tqdm boto3 requests pandas
conda activate py_univl
pip install torch==1.7.1+cu92
pip install git+https://github.com/Maluuba/nlg-eval.git@master
```
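A minimal check (not from the repo) that the environment above is usable before launching any of the distributed commands below:
```python
# Hypothetical environment check for the py_univl env created above.
import torch

print("torch:", torch.__version__)            # expect a 1.7.x build
print("CUDA available:", torch.cuda.is_available())
print("GPUs visible:", torch.cuda.device_count())
```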
## Caption
# Finetune on YoucookII
## Retrieval
1. Run retrieval task on **YoucookII**
```
INIT_MODEL=<from second phase>
python -m torch.distributed.launch --nproc_per_node=4 train_transcript_distributed.py \
DATATYPE="youcook"
TRAIN_CSV="data/youcookii/youcookii_train.csv"
VAL_CSV="data/youcookii/youcookii_val.csv"
DATA_PATH="data/youcookii/youcookii_data.pickle"
FEATURES_PATH="data/youcookii/youcookii_videos_features.pickle"
INIT_MODEL="weight/univl.pretrained.bin"
OUTPUT_ROOT="ckpts"
python -m torch.distributed.launch --nproc_per_node=4 \
main_task_retrieval.py \
--do_train --num_thread_reader=16 \
--epochs=5 --batch_size=32 \
--n_display=100 \
--train_csv ${TRAIN_CSV} \
--val_csv ${VAL_CSV} \
--data_path ${DATA_PATH} \
--features_path ${FEATURES_PATH} \
--output_dir ${OUTPUT_ROOT}/ckpt_youcook_retrieval --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 48 --max_frames 48 \
--batch_size_val 200 --visual_num_hidden_layers 6 \
--datatype ${DATATYPE} --init_model ${INIT_MODEL}
```
>The results are close to *R@1: 0.2269 - R@5: 0.5245 - R@10: 0.6586 - Median R: 5.0*
2. Run retrieval task on **MSRVTT**
```
DATATYPE="msrvtt"
TRAIN_CSV="data/msrvtt/MSRVTT_train.9k.csv"
VAL_CSV="data/msrvtt/MSRVTT_JSFUSION_test.csv"
DATA_PATH="data/msrvtt/MSRVTT_data.json"
FEATURES_PATH="data/msrvtt/msrvtt_videos_features.pickle"
INIT_MODEL="weight/univl.pretrained.bin"
OUTPUT_ROOT="ckpts"
python -m torch.distributed.launch --nproc_per_node=4 \
main_task_retrieval.py \
--do_train --num_thread_reader=16 \
--epochs=5 --batch_size=128 \
--n_display=100 \
--train_csv ${TRAIN_CSV} \
--val_csv ${VAL_CSV} \
--data_path ${DATA_PATH} \
--features_path ${FEATURES_PATH} \
--output_dir ${OUTPUT_ROOT}/ckpt_msrvtt_retrieval --bert_model bert-base-uncased \
--do_lower_case --lr 5e-5 --max_words 48 --max_frames 48 \
--batch_size_val 200 --visual_num_hidden_layers 6 \
--datatype ${DATATYPE} --expand_msrvtt_sentences --init_model ${INIT_MODEL}
```
>The results are close to *R@1: 0.2720 - R@5: 0.5570 - R@10: 0.6870 - Median R: 4.0*
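The repository's evaluation code is not shown in this diff; as a generic reference, R@K and Median R in the retrieval results above are conventionally derived from a text-to-video similarity matrix, roughly as sketched below (the matrix layout, with the true video for text `i` at column `i`, is an assumption):
```python
# Generic sketch of R@1/R@5/R@10/Median R from a similarity matrix sim, where
# sim[i, j] scores text i against video j and the ground-truth video of text i
# is at column i. Not the repository's evaluation code.
import numpy as np

def retrieval_metrics(sim):
    ranks = []
    for i, row in enumerate(sim):
        order = np.argsort(-row)                       # best-scoring videos first
        ranks.append(int(np.where(order == i)[0][0]))  # 0-based rank of the true video
    ranks = np.array(ranks)
    return {
        "R@1": float(np.mean(ranks < 1)),
        "R@5": float(np.mean(ranks < 5)),
        "R@10": float(np.mean(ranks < 10)),
        "Median R": float(np.median(ranks) + 1),       # reported 1-based
    }

print(retrieval_metrics(np.random.rand(10, 10)))       # toy example
```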
## Caption
Run caption task on **YoucookII**
```
TRAIN_CSV="data/youcookii/youcookii_train.csv"
VAL_CSV="data/youcookii/youcookii_val.csv"
DATA_PATH="data/youcookii/youcookii_data.pickle"
FEATURES_PATH="data/youcookii/youcookii_videos_features.pickle"
INIT_MODEL="weight/univl.pretrained.bin"
OUTPUT_ROOT="ckpts"
python -m torch.distributed.launch --nproc_per_node=4 \
main_task_caption.py \
--do_train --num_thread_reader=4 \
--epochs=5 --batch_size=16 \
--n_pair=-1 --n_display=100 \
--youcook_train_csv data/youcookii_train.csv \
--youcook_val_csv data/youcookii_val.csv \
--youcook_caption_path data/youcookii_caption_transcript.pickle \
--youcook_features_path_2D data/youcookii_videos_features \
--output_dir ckpt_youcook_caption --bert_model bert-base-uncased \
--n_display=100 \
--train_csv ${TRAIN_CSV} \
--val_csv ${VAL_CSV} \
--data_path ${DATA_PATH} \
--features_path ${FEATURES_PATH} \
--output_dir ${OUTPUT_ROOT}/ckpt_youcook_caption --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 128 --max_frames 96 \
--batch_size_val 64 --visual_num_hidden_layers 6 \
--decoder_num_hidden_layers 3 \
--decoder_num_hidden_layers 3 --stage_two \
--init_model ${INIT_MODEL}
```
>The results are close to
```
BLEU_1: 0.4746, BLEU_2: 0.3355, BLEU_3: 0.2423, BLEU_4: 0.1779
METEOR: 0.2261, ROUGE_L: 0.4697, CIDEr: 1.8631
```
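These caption scores come from the nlg-eval package in the requirements; a minimal, generic usage sketch follows (the hypothesis/reference strings are placeholders, and this is not necessarily how the training script invokes it):
```python
# Minimal nlg-eval sketch: only BLEU/METEOR/ROUGE_L/CIDEr are reported above,
# so the embedding-based metrics are disabled.
from nlgeval import NLGEval

nlg = NLGEval(no_skipthoughts=True, no_glove=True)
hyps = ["melt butter in a pan"]            # placeholder generated captions
refs = [["melt the butter in a pan"]]      # outer list: reference sets; inner: one reference per hypothesis
print(nlg.compute_metrics(ref_list=refs, hyp_list=hyps))  # {'Bleu_1': ..., 'METEOR': ..., ...}
```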
# Pretrain on HowTo100M
## Phase I
## Format of csv
```
video_id,feature_file
Z8xhli297v8,Z8xhli297v8.npy
...
```
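One way to produce such a CSV from a directory of per-video `.npy` feature files is sketched below (a hypothetical helper, not part of this commit; `features_dir` and the output path are placeholders):
```python
# Hypothetical helper: build a video_id,feature_file csv from <video_id>.npy files.
import os
import pandas as pd

features_dir = "data/features"             # placeholder path
rows = [{"video_id": os.path.splitext(f)[0], "feature_file": f}
        for f in sorted(os.listdir(features_dir)) if f.endswith(".npy")]
pd.DataFrame(rows, columns=["video_id", "feature_file"]).to_csv("data/HowTo100M.csv", index=False)
```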
## Stage I
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
python -m torch.distributed.launch --nproc_per_node=8 ${MODEL_PATH}/train_transcript_distributed.py \
python -m torch.distributed.launch --nproc_per_node=8 \
${MODEL_PATH}/main_pretrain.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=1920 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/pre_s3d_L48_V6_D3_Phase1 \
--features_path_2D ${DATA_PATH}/features \
--output_dir ${SAVE_PATH}/pre_trained/L48_V6_D3_Phase1 \
--features_path ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--caption_path ${DATA_PATH}/caption.pickle \
--data_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --gradient_accumulation_steps 16 \
--sampled_use_mil --load_checkpoint
```
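Note that `main_pretrain.py` (added later in this commit) divides `--batch_size` by `--gradient_accumulation_steps` and then by the GPU count when building the dataloader, so the Stage I numbers above translate into a small per-GPU micro-batch; a quick check of that arithmetic:
```python
# Arithmetic for the Stage I command above, following the batch_size handling
# in main_pretrain.py; each loaded item additionally groups --n_pair=3 clips.
batch_size, grad_accum, n_gpu = 1920, 16, 8
micro_batch = batch_size // grad_accum     # 120 items per forward pass, summed over GPUs
per_gpu = micro_batch // n_gpu             # 15 items per GPU per forward pass
effective = micro_batch * grad_accum       # 1920 items per optimizer update
print(micro_batch, per_gpu, effective)
```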
## Phase II
## Stage II
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
INIT_MODEL=<from first phase>
python -m torch.distributed.launch --nproc_per_node=8 ${MODEL_PATH}/train_transcript_distributed.py \
INIT_MODEL=<from first stage>
python -m torch.distributed.launch --nproc_per_node=8 \
${MODEL_PATH}/main_pretrain.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=960 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/pre_s3d_L48_V6_D3_Phase2 \
--features_path_2D ${DATA_PATH}/features \
--output_dir ${SAVE_PATH}/pre_trained/L48_V6_D3_Phase2 \
--features_path ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--caption_path ${DATA_PATH}/caption.pickle \
--data_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --decoder_num_hidden_layers 3 \
--gradient_accumulation_steps 60 \
--cross_model --pretrain_with_joint_sim --sampled_use_mil \
--pretrain_enhance_vmodal --pretrain_without_decoder \
--stage_two --pretrain_with_joint_sim --sampled_use_mil \
--pretrain_enhance_vmodal \
--load_checkpoint --init_model ${INIT_MODEL}
```
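`INIT_MODEL` for Stage II should point at a Stage I checkpoint. `main_pretrain.py` saves pretraining epochs as `pytorch_model.bin.pretrain.<epoch>` inside `--output_dir`, so a small helper to pick the latest one could look like this (the directory name simply mirrors the Stage I command above):
```python
# Hypothetical helper: locate the newest Stage I checkpoint to pass as --init_model,
# based on the pytorch_model.bin.pretrain.<epoch> naming used by save_model().
import glob
import os

phase1_dir = "models/pre_trained/L48_V6_D3_Phase1"   # the Stage I --output_dir
ckpts = glob.glob(os.path.join(phase1_dir, "pytorch_model.bin.pretrain.*"))
latest = max(ckpts, key=lambda p: int(p.rsplit(".", 1)[-1]))
print("INIT_MODEL =", latest)
```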
# Citation
If you find UniVL useful in your work, please cite the following paper:
```
@Article{Luo2020UniVL,
author = {Huaishao Luo and Lei Ji and Botian Shi and Haoyang Huang and Nan Duan and Tianrui Li and Jason Li and Taroon Bharti and Ming Zhou},
title = {UniVL: A Unified Video and Language Pre-Training Model for Multimodal Understanding and Generation},
journal = {arXiv preprint arXiv:2002.06353},
year = {2020},
}
```
# Acknowledgments
Our code is based on [pytorch-transformers v0.4.0](https://github.com/huggingface/transformers/tree/v0.4.0). We thank the authors for their wonderful open-source efforts.

View file

@ -1,9 +0,0 @@
# pip install prefetch_generator
from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator
class DataLoaderX(DataLoader):
def __iter__(self):
# transforms generator into a background-thread generator.
return BackgroundGenerator(super().__iter__(), max_prefetch=1)

View file

@ -7,11 +7,9 @@ from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import re
import random
import io
class Youtube_Transcript_DataLoader(Dataset):
class Youtube_DataLoader(Dataset):
"""
Youtube dataset loader.
Note: Uses the transcript as the caption for the masked decoder pretraining task.
@ -21,16 +19,13 @@ class Youtube_Transcript_DataLoader(Dataset):
self,
csv,
features_path,
caption,
data_dict,
tokenizer,
min_time=10.0,
features_path_3D=None,
feature_framerate=1.0,
feature_framerate_3D=24.0 / 16.0,
max_words=30,
min_words=0,
n_pair=-1, # -1 for test
pad_token="[PAD]",
n_pair=-1,
max_frames=100,
with_long_context=True,
use_mil=False,
@ -43,24 +38,16 @@ class Youtube_Transcript_DataLoader(Dataset):
Args:
"""
self.csv = pd.read_csv(csv)
self.features_path_2D = features_path
self.features_path_3D = features_path_3D
self.caption = caption
self.features_path = features_path
self.data_dict = data_dict
self.min_time = min_time
self.feature_framerate = feature_framerate
self.feature_framerate_3D = feature_framerate_3D
self.max_words = max_words
self.max_frames = max_frames
self.min_words = min_words
self.tokenizer = tokenizer
self.n_pair = n_pair
self.pad_token = pad_token
self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
self.feature_path = {'2d': features_path}
if features_path_3D != '':
self.feature_path['3d'] = features_path_3D
self.with_long_context = with_long_context
self.feature_size = video_dim
self.only_sim = only_sim
@ -69,7 +56,7 @@ class Youtube_Transcript_DataLoader(Dataset):
self.use_mil = use_mil
self.sampled_use_mil = sampled_use_mil
if self.sampled_use_mil: # sample from each video, has a high priority than use_mil.
if self.sampled_use_mil: # sample from each video, has a higher priority than use_mil.
self.use_mil = True
if self.use_mil:
@ -82,8 +69,8 @@ class Youtube_Transcript_DataLoader(Dataset):
self.iter2video_pairslist_dict = {}
iter_idx_mil_ = 0
for video_id in video_id_list:
caption = self.caption[video_id]
n_caption = len(caption['start'])
data_dict = self.data_dict[video_id]
n_caption = len(data_dict['start'])
sub_list = []
if self.n_pair < 0 or self.n_pair == 1:
@ -114,46 +101,38 @@ class Youtube_Transcript_DataLoader(Dataset):
def __len__(self):
return self.iter_num
def _mask_tokens(self, words, orig2token_tuple_list=None, chunk_positions=None):
def _mask_tokens(self, words):
token_labels = []
masked_tokens = words.copy()
if chunk_positions is not None and len(chunk_positions)>0:
token_labels = [-1] * len(masked_tokens)
for chunk_ind in chunk_positions:
start_, len_ = orig2token_tuple_list[chunk_ind]
if random.random() < 0.15:
token_labels[start_:start_+len_] = [self.tokenizer.vocab[tk_] for tk_ in masked_tokens[start_:start_+len_]]
masked_tokens[start_:start_+len_] = ["[MASK]"] * len_
else:
for token_id, token in enumerate(masked_tokens):
if token_id == 0 or token_id == len(masked_tokens) - 1:
token_labels.append(-1)
continue
prob = random.random()
if prob < 0.15:
prob /= 0.15
if prob < 0.8:
masked_tokens[token_id] = "[MASK]"
elif prob < 0.9:
masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
try:
token_labels.append(self.tokenizer.vocab[token])
except KeyError:
token_labels.append(self.tokenizer.vocab["[UNK]"])
else:
token_labels.append(-1)
for token_id, token in enumerate(masked_tokens):
if token_id == 0 or token_id == len(masked_tokens) - 1:
token_labels.append(-1)
continue
prob = random.random()
if prob < 0.15:
prob /= 0.15
if prob < 0.8:
masked_tokens[token_id] = "[MASK]"
elif prob < 0.9:
masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
try:
token_labels.append(self.tokenizer.vocab[token])
except KeyError:
token_labels.append(self.tokenizer.vocab["[UNK]"])
else:
token_labels.append(-1)
return masked_tokens, token_labels
def _get_text(self, video_id, n_pair_max, sub_ids=None, only_sim=False, enhance_vmodel=False):
caption = self.caption[video_id]
data_dict = self.data_dict[video_id]
if self.use_mil:
k = len(sub_ids)
r_ind = sub_ids
else:
n_caption = len(caption['start'])
n_caption = len(data_dict['start'])
if n_pair_max == -1:
k = n_caption
r_ind = range(n_caption)
@ -181,16 +160,10 @@ class Youtube_Transcript_DataLoader(Dataset):
for i in range(k):
ind = r_ind[i]
words, start_, end_ = self._get_single_transcript(caption, ind, with_long_context=self.with_long_context)
words, start_, end_ = self._get_single_transcript(data_dict, ind, with_long_context=self.with_long_context)
caption_words = words.copy()
starts[i], ends[i] = start_, end_
orig2token_tuple_list, chunk_positions = None, None
# # For entity mask
# orig_sentence, orig_to_token_map = generate_org_sentence_from_tokens(words.copy())
# four_tuples_ = generate_knoledge_candidate(orig_sentence, orig_to_token_map)
# _, orig2token_tuple_list, chunk_positions, _ = four_tuples_
if enhance_vmodel:
words = [] # mask all input text
@ -199,7 +172,7 @@ class Youtube_Transcript_DataLoader(Dataset):
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + ["[SEP]"]
input_ids = self.tokenizer.convert_tokens_to_ids(words)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
@ -216,17 +189,13 @@ class Youtube_Transcript_DataLoader(Dataset):
pairs_segment[i] = np.array(segment_ids)
if only_sim is False:
# # For entity mask : +1 due to [CLS]
# orig2token_tuple_list = [(s_+1, len_) for s_, len_ in orig2token_tuple_list if s_+len_ <= total_length_with_CLS]
# chunk_positions = [ind for ind in chunk_positions if ind<len(orig2token_tuple_list)]
# For generate captions
if len(caption_words) > total_length_with_CLS:
caption_words = caption_words[:total_length_with_CLS]
input_caption_words = ["[CLS]"] + caption_words
output_caption_words = caption_words + ["[SEP]"]
masked_tokens, token_labels = self._mask_tokens(words, orig2token_tuple_list, chunk_positions)
masked_tokens, token_labels = self._mask_tokens(words)
masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
masked_input_caption_words, input_token_labels = self._mask_tokens(input_caption_words)
input_caption_words = masked_input_caption_words.copy()
@ -259,16 +228,16 @@ class Youtube_Transcript_DataLoader(Dataset):
return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, \
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, starts, ends
def _get_single_transcript(self, caption, ind, with_long_context=True):
def _get_single_transcript(self, data_dict, ind, with_long_context=True):
start, end = ind, ind
words = self.tokenizer.tokenize(str(caption['text'][ind]))
diff = caption['end'][end] - caption['start'][start]
words = self.tokenizer.tokenize(str(data_dict['text'][ind]))
diff = data_dict['end'][end] - data_dict['start'][start]
while with_long_context and (len(words) < self.min_words or diff < self.min_time):
if start > 0 and end < len(caption['end']) - 1:
next_words = self.tokenizer.tokenize(str(caption['text'][end + 1]))
prev_words = self.tokenizer.tokenize(str(caption['text'][start - 1]))
d1 = caption['end'][end + 1] - caption['start'][start]
d2 = caption['end'][end] - caption['start'][start - 1]
if start > 0 and end < len(data_dict['end']) - 1:
next_words = self.tokenizer.tokenize(str(data_dict['text'][end + 1]))
prev_words = self.tokenizer.tokenize(str(data_dict['text'][start - 1]))
d1 = data_dict['end'][end + 1] - data_dict['start'][start]
d2 = data_dict['end'][end] - data_dict['start'][start - 1]
if (self.min_time > 0 and d2 <= d1) or \
(self.min_time == 0 and len(next_words) <= len(prev_words)):
start -= 1
@ -277,17 +246,17 @@ class Youtube_Transcript_DataLoader(Dataset):
end += 1
words.extend(next_words)
elif start > 0:
words = self.tokenizer.tokenize(str(caption['text'][start - 1])) + words
words = self.tokenizer.tokenize(str(data_dict['text'][start - 1])) + words
start -= 1
elif end < len(caption['end']) - 1:
words.extend(self.tokenizer.tokenize(str(caption['text'][end + 1])))
elif end < len(data_dict['end']) - 1:
words.extend(self.tokenizer.tokenize(str(data_dict['text'][end + 1])))
end += 1
else:
break
diff = caption['end'][end] - caption['start'][start]
return words, caption['start'][start], caption['end'][end]
diff = data_dict['end'][end] - data_dict['start'][start]
return words, data_dict['start'][start], data_dict['end'][end]
def _expand_video_slice(self, s, e, si, ei, fps, video_features, fps_k):
def _expand_video_slice(self, s, e, si, ei, fps, video_features):
start = int(s[si] * fps)
end = int(e[ei] * fps) + 1
@ -311,12 +280,6 @@ class Youtube_Transcript_DataLoader(Dataset):
start, end = end, start
video_slice = video_features[start:end]
# # Note: to alignment the features between 2d and 3d, due to the fps is different.
# if fps_k == "2d":
# indices = sorted([idx for idx in range(len(video_slice)) if idx % 2 == 1]
# + [idx for idx in range(len(video_slice))])
# video_slice = np.take(video_slice, indices, axis=0)
if self.max_frames < video_slice.shape[0]:
video_slice = video_slice[:self.max_frames]
@ -326,32 +289,20 @@ class Youtube_Transcript_DataLoader(Dataset):
video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
max_video_length = [0] * len(s)
f_title = "feature_file_2D"
fps_k = "2d"
video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
feature_file = os.path.join(self.features_path, self.csv["feature_file"].values[idx])
try:
video_features = np.load(feature_file)
for i in range(len(s)):
# start = int(s[i] * self.fps[fps_k])
# end = int(e[i] * self.fps[fps_k]) + 1
# video_slice = video_features[start:end]
#
# if self.max_frames < video_slice.shape[0]:
# video_slice = video_slice[:self.max_frames]
if len(video_features) < 1:
raise ValueError("{} is empty.".format(feature_file))
video_slice, start, end = self._expand_video_slice(s, e, i, i,
self.fps[fps_k], video_features, fps_k)
video_slice, start, end = self._expand_video_slice(s, e, i, i, self.feature_framerate, video_features)
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
pass
# video_slice = video_features
# if self.max_frames < video_slice.shape[0]:
# video_slice = video_slice[:self.max_frames]
else:
video[i][:slice_shape[0]] = video_slice
except Exception as e:
@ -387,7 +338,7 @@ class Youtube_Transcript_DataLoader(Dataset):
return "%02d:%02d:%02d" % (h, m2, s)
def __getitem__(self, feature_idx):
if self.sampled_use_mil: # sample from each video, has a high priority than use_mil.
if self.sampled_use_mil: # sample from each video, has a higher priority than use_mil.
idx = feature_idx
video_id = self.csv['video_id'].values[idx]
sub_list = self.iter2video_pairslist_dict[video_id]

View file

@ -0,0 +1,352 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import os
from torch.utils.data import Dataset
import numpy as np
import pickle
import pandas as pd
from collections import defaultdict
import json
import random
class MSRVTT_DataLoader(Dataset):
"""MSRVTT dataset loader."""
def __init__(
self,
csv_path,
features_path,
tokenizer,
max_words=30,
feature_framerate=1.0,
max_frames=100,
):
self.data = pd.read_csv(csv_path)
self.feature_dict = pickle.load(open(features_path, 'rb'))
self.feature_framerate = feature_framerate
self.max_words = max_words
self.max_frames = max_frames
self.tokenizer = tokenizer
self.feature_size = self.feature_dict[self.data['video_id'].values[0]].shape[-1]
def __len__(self):
return len(self.data)
def _get_text(self, video_id, sentence):
choice_video_ids = [video_id]
n_caption = len(choice_video_ids)
k = n_caption
pairs_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)
for i, video_id in enumerate(choice_video_ids):
words = self.tokenizer.tokenize(sentence)
words = ["[CLS]"] + words
total_length_with_CLS = self.max_words - 1
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + ["[SEP]"]
# Mask Language Model <-----
token_labels = []
masked_tokens = words.copy()
for token_id, token in enumerate(masked_tokens):
if token_id == 0 or token_id == len(masked_tokens) - 1:
token_labels.append(-1)
continue
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
masked_tokens[token_id] = "[MASK]"
# 10% randomly change token to random token
elif prob < 0.9:
masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
# -> rest 10% randomly keep current token
# append current token to output (we will predict these later)
try:
token_labels.append(self.tokenizer.vocab[token])
except KeyError:
# For unknown words (should not occur with BPE vocab)
token_labels.append(self.tokenizer.vocab["[UNK]"])
# print("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
else:
# no masking token (will be ignored by loss function later)
token_labels.append(-1)
# -----> Mask Language Model
input_ids = self.tokenizer.convert_tokens_to_ids(words)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
while len(input_ids) < self.max_words:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
masked_token_ids.append(0)
token_labels.append(-1)
assert len(input_ids) == self.max_words
assert len(input_mask) == self.max_words
assert len(segment_ids) == self.max_words
assert len(masked_token_ids) == self.max_words
assert len(token_labels) == self.max_words
pairs_text[i] = np.array(input_ids)
pairs_mask[i] = np.array(input_mask)
pairs_segment[i] = np.array(segment_ids)
pairs_masked_text[i] = np.array(masked_token_ids)
pairs_token_labels[i] = np.array(token_labels)
return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, choice_video_ids
def _get_video(self, choice_video_ids):
video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long)
max_video_length = [0] * len(choice_video_ids)
video = np.zeros((len(choice_video_ids), self.max_frames, self.feature_size), dtype=np.float)
for i, video_id in enumerate(choice_video_ids):
video_slice = self.feature_dict[video_id]
if self.max_frames < video_slice.shape[0]:
video_slice = video_slice[:self.max_frames]
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
print("video_id: {}".format(video_id))
else:
video[i][:slice_shape[0]] = video_slice
for i, v_length in enumerate(max_video_length):
video_mask[i][:v_length] = [1] * v_length
# Mask Frame Model <-----
video_labels_index = [[] for _ in range(len(choice_video_ids))]
masked_video = video.copy()
for i, video_pair_ in enumerate(masked_video):
for j, _ in enumerate(video_pair_):
if j < max_video_length[i]:
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
masked_video[i][j] = [0.] * video.shape[-1]
video_labels_index[i].append(j)
else:
video_labels_index[i].append(-1)
else:
video_labels_index[i].append(-1)
video_labels_index = np.array(video_labels_index, dtype=np.long)
# -----> Mask Frame Model
return video, video_mask, masked_video, video_labels_index
def __getitem__(self, idx):
video_id = self.data['video_id'].values[idx]
sentence = self.data['sentence'].values[idx]
pairs_text, pairs_mask, pairs_segment, \
pairs_masked_text, pairs_token_labels, choice_video_ids = self._get_text(video_id, sentence)
video, video_mask, masked_video, video_labels_index = self._get_video(choice_video_ids)
return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index
class MSRVTT_TrainDataLoader(Dataset):
"""MSRVTT train dataset loader."""
def __init__(
self,
csv_path,
json_path,
features_path,
tokenizer,
max_words=30,
feature_framerate=1.0,
max_frames=100,
unfold_sentences=False,
):
self.csv = pd.read_csv(csv_path)
self.data = json.load(open(json_path, 'r'))
self.feature_dict = pickle.load(open(features_path, 'rb'))
self.feature_framerate = feature_framerate
self.max_words = max_words
self.max_frames = max_frames
self.tokenizer = tokenizer
self.feature_size = self.feature_dict[self.csv['video_id'].values[0]].shape[-1]
self.unfold_sentences = unfold_sentences
self.sample_len = 0
if self.unfold_sentences:
train_video_ids = list(self.csv['video_id'].values)
self.sentences_dict = {}
for itm in self.data['sentences']:
if itm['video_id'] in train_video_ids:
self.sentences_dict[len(self.sentences_dict)] = (itm['video_id'], itm['caption'])
self.sample_len = len(self.sentences_dict)
else:
num_sentences = 0
self.sentences = defaultdict(list)
s_video_id_set = set()
for itm in self.data['sentences']:
self.sentences[itm['video_id']].append(itm['caption'])
num_sentences += 1
s_video_id_set.add(itm['video_id'])
# Use to find the clips in the same video
self.parent_ids = {}
self.children_video_ids = defaultdict(list)
for itm in self.data['videos']:
vid = itm["video_id"]
url_posfix = itm["url"].split("?v=")[-1]
self.parent_ids[vid] = url_posfix
self.children_video_ids[url_posfix].append(vid)
self.sample_len = len(self.csv)
def __len__(self):
return self.sample_len
def _get_text(self, video_id, caption=None):
k = 1
choice_video_ids = [video_id]
pairs_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)
for i, video_id in enumerate(choice_video_ids):
if caption is not None:
words = self.tokenizer.tokenize(caption)
else:
words = self._get_single_text(video_id)
words = ["[CLS]"] + words
total_length_with_CLS = self.max_words - 1
if len(words) > total_length_with_CLS:
words = words[:total_length_with_CLS]
words = words + ["[SEP]"]
# Mask Language Model <-----
token_labels = []
masked_tokens = words.copy()
for token_id, token in enumerate(masked_tokens):
if token_id == 0 or token_id == len(masked_tokens) - 1:
token_labels.append(-1)
continue
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
prob /= 0.15
# 80% randomly change token to mask token
if prob < 0.8:
masked_tokens[token_id] = "[MASK]"
# 10% randomly change token to random token
elif prob < 0.9:
masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
# -> rest 10% randomly keep current token
# append current token to output (we will predict these later)
try:
token_labels.append(self.tokenizer.vocab[token])
except KeyError:
# For unknown words (should not occur with BPE vocab)
token_labels.append(self.tokenizer.vocab["[UNK]"])
# print("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
else:
# no masking token (will be ignored by loss function later)
token_labels.append(-1)
# -----> Mask Language Model
input_ids = self.tokenizer.convert_tokens_to_ids(words)
input_mask = [1] * len(input_ids)
segment_ids = [0] * len(input_ids)
masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
while len(input_ids) < self.max_words:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
masked_token_ids.append(0)
token_labels.append(-1)
assert len(input_ids) == self.max_words
assert len(input_mask) == self.max_words
assert len(segment_ids) == self.max_words
assert len(masked_token_ids) == self.max_words
assert len(token_labels) == self.max_words
pairs_text[i] = np.array(input_ids)
pairs_mask[i] = np.array(input_mask)
pairs_segment[i] = np.array(segment_ids)
pairs_masked_text[i] = np.array(masked_token_ids)
pairs_token_labels[i] = np.array(token_labels)
return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, choice_video_ids
def _get_single_text(self, video_id):
rind = random.randint(0, len(self.sentences[video_id]) - 1)
caption = self.sentences[video_id][rind]
words = self.tokenizer.tokenize(caption)
return words
def _get_video(self, choice_video_ids):
video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long)
max_video_length = [0] * len(choice_video_ids)
video = np.zeros((len(choice_video_ids), self.max_frames, self.feature_size), dtype=np.float)
for i, video_id in enumerate(choice_video_ids):
video_slice = self.feature_dict[video_id]
if self.max_frames < video_slice.shape[0]:
video_slice = video_slice[:self.max_frames]
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
print("video_id: {}".format(video_id))
else:
video[i][:slice_shape[0]] = video_slice
for i, v_length in enumerate(max_video_length):
video_mask[i][:v_length] = [1] * v_length
# Mask Frame Model <-----
video_labels_index = [[] for _ in range(len(choice_video_ids))]
masked_video = video.copy()
for i, video_pair_ in enumerate(masked_video):
for j, _ in enumerate(video_pair_):
if j < max_video_length[i]:
prob = random.random()
# mask token with 15% probability
if prob < 0.15:
masked_video[i][j] = [0.] * video.shape[-1]
video_labels_index[i].append(j)
else:
video_labels_index[i].append(-1)
else:
video_labels_index[i].append(-1)
video_labels_index = np.array(video_labels_index, dtype=np.long)
# -----> Mask Frame Model
return video, video_mask, masked_video, video_labels_index
def __getitem__(self, idx):
if self.unfold_sentences:
video_id, caption = self.sentences_dict[idx]
else:
video_id, caption = self.csv['video_id'].values[idx], None
pairs_text, pairs_mask, pairs_segment, \
pairs_masked_text, pairs_token_labels, choice_video_ids = self._get_text(video_id, caption)
video, video_mask, masked_video, video_labels_index = self._get_video(choice_video_ids)
return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index
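For orientation, a hedged sketch of how the new MSRVTT train loader above might be instantiated with the paths from the README's MSRVTT command; the module name is assumed (the file name is not visible in this view), and whether `main_task_retrieval.py` builds it exactly this way is not shown in this diff:
```python
# Hypothetical instantiation of MSRVTT_TrainDataLoader; module name assumed.
from modules.tokenization import BertTokenizer
from dataloader_msrvtt import MSRVTT_TrainDataLoader   # assumed module name

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
train_set = MSRVTT_TrainDataLoader(
    csv_path="data/msrvtt/MSRVTT_train.9k.csv",
    json_path="data/msrvtt/MSRVTT_data.json",
    features_path="data/msrvtt/msrvtt_videos_features.pickle",
    tokenizer=tokenizer,
    max_words=48,
    max_frames=48,
    unfold_sentences=True,     # presumably what --expand_msrvtt_sentences toggles
)
print(len(train_set))
```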

View file

@ -7,52 +7,35 @@ from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import pickle
import re
import random
import io
class Youcook_Transcript_NoPair_DataLoader(Dataset):
class Youcook_Caption_DataLoader(Dataset):
"""Youcook dataset loader."""
def __init__(
self,
csv,
data_path,
features_path,
caption,
tokenizer,
min_time=10.0,
features_path_3D=None,
feature_framerate=1.0,
feature_framerate_3D=24.0 / 16.0,
max_words=30,
min_words=0,
n_pair=-1, # -1 for test
pad_token="[PAD]",
max_frames=100,
):
"""
Args:
"""
self.csv = pd.read_csv(csv)
self.features_path_2D = features_path
self.features_path_3D = features_path_3D
self.caption = caption
self.min_time = min_time
self.data_dict = pickle.load(open(data_path, 'rb'))
self.feature_dict = pickle.load(open(features_path, 'rb'))
self.feature_framerate = feature_framerate
self.feature_framerate_3D = feature_framerate_3D
self.max_words = max_words
self.max_frames = max_frames
self.min_words = min_words
self.tokenizer = tokenizer
self.n_pair = n_pair
self.pad_token = pad_token
self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
self.feature_path = {'2d': features_path}
if features_path_3D != '':
self.feature_path['3d'] = features_path_3D
_feature_file = os.path.join(features_path, self.csv["feature_file_2D"].values[0])
self.feature_size = np.load(_feature_file).shape[-1]
self.feature_size = self.feature_dict[self.csv["feature_file"].values[0]].shape[-1]
# Get iterator video ids
video_id_list = [itm for itm in self.csv['video_id'].values]
@ -61,8 +44,8 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
self.iter2video_pairs_dict = {}
iter_idx_ = 0
for video_id in video_id_list:
caption = self.caption[video_id]
n_caption = len(caption['start'])
data_dict = self.data_dict[video_id]
n_caption = len(data_dict['start'])
for sub_id in range(n_caption):
self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
iter_idx_ += 1
@ -71,7 +54,7 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
return len(self.iter2video_pairs_dict)
def _get_text(self, video_id, sub_id):
caption = self.caption[video_id]
data_dict = self.data_dict[video_id]
k = 1
r_ind = [sub_id]
@ -89,10 +72,10 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
for i in range(k):
ind = r_ind[i]
start_, end_ = caption['start'][ind], caption['end'][ind]
start_, end_ = data_dict['start'][ind], data_dict['end'][ind]
starts[i], ends[i] = start_, end_
total_length_with_CLS = self.max_words - 1
words = self.tokenizer.tokenize(caption['transcript'][ind])
words = self.tokenizer.tokenize(data_dict['transcript'][ind])
words = ["[CLS]"] + words
if len(words) > total_length_with_CLS:
@ -156,7 +139,7 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
pairs_token_labels[i] = np.array(token_labels)
# For generate captions
caption_words = self.tokenizer.tokenize(caption['text'][ind])
caption_words = self.tokenizer.tokenize(data_dict['text'][ind])
if len(caption_words) > total_length_with_CLS:
caption_words = caption_words[:total_length_with_CLS]
input_caption_words = ["[CLS]"] + caption_words
@ -184,16 +167,12 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
def _get_video(self, idx, s, e):
video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
max_video_length = [0] * len(s)
f_title = "feature_file_2D"
fps_k = "2d"
feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
video_features = np.load(feature_file)
video_features = self.feature_dict[self.csv["feature_file"].values[idx]]
video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
for i in range(len(s)):
start = int(s[i] * self.fps[fps_k])
end = int(e[i] * self.fps[fps_k]) + 1
start = int(s[i] * self.feature_framerate)
end = int(e[i] * self.feature_framerate) + 1
video_slice = video_features[start:end]
if self.max_frames < video_slice.shape[0]:
@ -202,7 +181,7 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
print("video_id: {}, start: {}, end: {}".format(feature_file, start, end))
print("video_id: {}, start: {}, end: {}".format(self.csv["video_id"].values[idx], start, end))
# pass
else:
video[i][:slice_shape[0]] = video_slice

View file

@ -7,46 +7,31 @@ from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import pickle
import random
class Youcook_NoPair_DataLoader(Dataset):
class Youcook_DataLoader(Dataset):
"""Youcook dataset loader."""
def __init__(
self,
csv,
data_path,
features_path,
caption,
tokenizer,
min_time=10.0,
features_path_3D=None,
feature_framerate=1.0,
feature_framerate_3D=24.0 / 16.0,
max_words=30,
min_words=0,
n_pair=-1, # -1 for test
pad_token="[PAD]",
max_frames=100,
):
"""
Args:
"""
self.csv = pd.read_csv(csv)
self.features_path_2D = features_path
self.features_path_3D = features_path_3D
self.caption = caption
self.min_time = min_time
self.data_dict = pickle.load(open(data_path, 'rb'))
self.feature_dict = pickle.load(open(features_path, 'rb'))
self.feature_framerate = feature_framerate
self.feature_framerate_3D = feature_framerate_3D
self.max_words = max_words
self.max_frames = max_frames
self.min_words = min_words
self.tokenizer = tokenizer
self.n_pair = n_pair
self.pad_token = pad_token
self.fps = {'2d': feature_framerate, '3d': feature_framerate_3D}
self.feature_path = {'2d': features_path}
if features_path_3D != '':
self.feature_path['3d'] = features_path_3D
# Get iterator video ids
video_id_list = [itm for itm in self.csv['video_id'].values]
@ -55,8 +40,8 @@ class Youcook_NoPair_DataLoader(Dataset):
self.iter2video_pairs_dict = {}
iter_idx_ = 0
for video_id in video_id_list:
caption = self.caption[video_id]
n_caption = len(caption['start'])
data_dict = self.data_dict[video_id]
n_caption = len(data_dict['start'])
for sub_id in range(n_caption):
self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
iter_idx_ += 1
@ -65,9 +50,8 @@ class Youcook_NoPair_DataLoader(Dataset):
return len(self.iter2video_pairs_dict)
def _get_text(self, video_id, sub_id):
caption = self.caption[video_id]
k = 1
r_ind = [sub_id]
data_dict = self.data_dict[video_id]
k, r_ind = 1, [sub_id]
starts = np.zeros(k)
ends = np.zeros(k)
@ -79,9 +63,8 @@ class Youcook_NoPair_DataLoader(Dataset):
for i in range(k):
ind = r_ind[i]
# Note: n_pair_max=-1 means eval
words = self.tokenizer.tokenize(caption['text'][ind])
start_, end_ = caption['start'][ind], caption['end'][ind]
words = self.tokenizer.tokenize(data_dict['text'][ind])
start_, end_ = data_dict['start'][ind], data_dict['end'][ind]
starts[i], ends[i] = start_, end_
words = ["[CLS]"] + words
@ -151,16 +134,12 @@ class Youcook_NoPair_DataLoader(Dataset):
def _get_video(self, idx, s, e):
video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
max_video_length = [0] * len(s)
f_title = "feature_file_2D"
fps_k = "2d"
feature_file = os.path.join(self.feature_path[fps_k], self.csv[f_title].values[idx])
video_features = np.load(feature_file)
video_features = self.feature_dict[self.csv["feature_file"].values[idx]]
video = np.zeros((len(s), self.max_frames, video_features.shape[-1]), dtype=np.float)
for i in range(len(s)):
start = int(s[i] * self.fps[fps_k])
end = int(e[i] * self.fps[fps_k]) + 1
start = int(s[i] * self.feature_framerate)
end = int(e[i] * self.feature_framerate) + 1
video_slice = video_features[start:end]
if self.max_frames < video_slice.shape[0]:
@ -169,7 +148,7 @@ class Youcook_NoPair_DataLoader(Dataset):
slice_shape = video_slice.shape
max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
if len(video_slice) < 1:
print("video_id: {}, start: {}, end: {}".format(feature_file, start, end))
print("video_id: {}, start: {}, end: {}".format(self.csv["video_id"].values[idx], start, end))
else:
video[i][:slice_shape[0]] = video_slice

410
main_pretrain.py Normal file
View file

@ -0,0 +1,410 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import torch
from torch.utils.data import (SequentialSampler)
import numpy as np
import random
import os
from collections import OrderedDict
import pickle
import logging
import time
import argparse
from modules.tokenization import BertTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import UniVL
from modules.optimization import BertAdam
from dataloader_howto100m import Youtube_DataLoader
from torch.utils.data import DataLoader
from util import get_logger
torch.distributed.init_process_group(backend="nccl")
global logger
def get_args(description='UniVL on Pretrain'):
parser = argparse.ArgumentParser(description=description)
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument('--train_csv', type=str, default='data/HowTo100M_v1.csv', help='train csv')
parser.add_argument('--features_path', type=str, default='feature', help='feature path for 2D features')
parser.add_argument('--data_path', type=str, default='data/data.pickle', help='data pickle file path')
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
parser.add_argument('--n_display', type=int, default=100, help='Information display frequency')
parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--max_words', type=int, default=20, help='')
parser.add_argument('--max_frames', type=int, default=100, help='')
parser.add_argument('--min_words', type=int, default=0, help='')
parser.add_argument('--feature_framerate', type=int, default=1, help='')
parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True,
help="Bert pre-trained model")
parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module")
parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module")
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--n_gpu', type=int, default=1, help="Overwritten at runtime.")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--world_size", default=0, type=int, help="distribted training")
parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use sampled MIL; has a higher priority than use_mil.")
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Number of hidden layers in the text encoder.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Number of hidden layers in the visual encoder.")
parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Number of hidden layers in the cross encoder.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Number of hidden layers in the decoder.")
parser.add_argument('--stage_two', action='store_true', help="Whether training with decoder.")
parser.add_argument('--pretrain_enhance_vmodal', action='store_true', help="Enhance visual and other modalities when pretraining.")
parser.add_argument("--load_checkpoint", action="store_true")
parser.add_argument("--checkpoint_model", default="pytorch_model.bin.checkpoint", type=str, required=False,
help="Save the last model as a checkpoint.")
args = parser.parse_args()
if args.sampled_use_mil: # sample from each video, has a higher priority than use_mil.
args.use_mil = True
# Check parameters
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if not args.do_pretrain:
raise ValueError("`do_pretrain` must be True.")
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
args.checkpoint_model = '{}_{}_{}_{}.checkpoint'.format(args.checkpoint_model, args.bert_model, args.max_words, args.max_frames)
return args
def set_seed_logger(args):
global logger
random.seed(args.seed)
os.environ['PYTHONHASHSEED'] = str(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(args.local_rank)
args.world_size = world_size
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
if args.local_rank == 0:
logger.info("Effective parameters:")
for key in sorted(args.__dict__):
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
return args
def init_device(args, local_rank):
global logger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)
n_gpu = torch.cuda.device_count()
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
args.n_gpu = n_gpu
if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))
return device, n_gpu
def init_model(args, device, n_gpu, local_rank):
if args.init_model:
model_state_dict = torch.load(args.init_model, map_location='cpu')
else:
model_state_dict = None
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
return model
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
if hasattr(model, 'module'):
model = model.module
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]
no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n]
no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n]
decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n]
decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n]
optimizer_grouped_parameters = [
{'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr},
{'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01},
{'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
{'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
]
scheduler = None
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
max_grad_norm=1.0)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
output_device=local_rank, find_unused_parameters=True)
return optimizer, scheduler, model
def dataloader_pretrain(args, tokenizer, only_sim=False):
if args.local_rank == 0:
logger.info('Loading captions: {}'.format(args.data_path))
data_dict = pickle.load(open(args.data_path, 'rb'))
if args.local_rank == 0:
logger.info('Done, data_dict length: {}'.format(len(data_dict)))
dataset = Youtube_DataLoader(
csv=args.train_csv,
features_path=args.features_path,
data_dict=data_dict,
min_time=args.min_time,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=args.n_pair,
max_frames=args.max_frames,
use_mil=args.use_mil,
only_sim=only_sim,
sampled_use_mil=args.sampled_use_mil,
pretrain_enhance_vmodal=args.pretrain_enhance_vmodal,
video_dim=args.video_dim,
)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
pin_memory=False,
shuffle=(sampler is None),
sampler=sampler,
drop_last=True,
)
return dataloader, len(dataset), sampler
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
def save_model(epoch, args, model, local_rank, type_name="", global_step=-1, optimizer=None):
model_to_save = model.module if hasattr(model, 'module') else model
output_model_file = os.path.join(
args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch))
torch.save(model_to_save.state_dict(), output_model_file)
logger.info("Model saved to %s", output_model_file)
if global_step != -1 and optimizer is not None:
state_dict = {
'epoch': epoch,
'global_step': global_step,
'model_state_dict': model_to_save.state_dict(),
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
}
checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
torch.save(state_dict, checkpoint_model_file)
logger.info("Checkpoint is saved. use `load_checkpoint` to recovery it.")
return output_model_file
def load_model(epoch, args, n_gpu, device, model, global_step=0, model_file=None):
if model_file is None or len(model_file) == 0:
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
last_optim_state = None
checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
if epoch == -1 and args.load_checkpoint and os.path.exists(checkpoint_model_file):
checkpoint_state = torch.load(checkpoint_model_file, map_location='cpu')
epoch = checkpoint_state['epoch']
global_step = checkpoint_state['global_step']
model_state_dict = checkpoint_state['model_state_dict']
last_optim_state = checkpoint_state['last_optimizer_state']
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
if args.local_rank == 0:
logger.info("Checkpoint loaded from %s", checkpoint_model_file)
elif os.path.exists(model_file):
model_state_dict = torch.load(model_file, map_location='cpu')
if args.local_rank == 0:
logger.info("Model loaded from %s", model_file)
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
return epoch, global_step, last_optim_state, model
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
global logger
torch.cuda.empty_cache()
model.train()
log_step = args.n_display
start_time = time.time()
total_loss = 0
for step, batch in enumerate(train_dataloader):
batch = tuple(t.to(device=device, non_blocking=True) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index,\
pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch
loss = model(input_ids, segment_ids, input_mask, video, video_mask,
pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels,
masked_video=masked_video, video_labels_index=video_labels_index,
input_caption_ids=pairs_input_caption_ids, decoder_mask=pairs_decoder_mask,
output_caption_ids=pairs_output_caption_ids)
if n_gpu > 1:
loss = loss.mean()
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
total_loss += float(loss)
if (step + 1) % args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if scheduler is not None:
scheduler.step()
optimizer.step()
optimizer.zero_grad()
global_step += 1
if global_step % log_step == 0 and local_rank == 0:
logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
args.epochs, step + 1,
len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
float(loss) * args.gradient_accumulation_steps,
(time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
start_time = time.time()
total_loss = total_loss / len(train_dataloader)
return total_loss, global_step
def main():
global logger
args = get_args()
args = set_seed_logger(args)
device, n_gpu = init_device(args, args.local_rank)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = init_model(args, device, n_gpu, args.local_rank)
only_sim = model.module._stage_one if hasattr(model, 'module') else model._stage_one
train_dataloader, train_length, sampler = dataloader_pretrain(args, tokenizer, only_sim=only_sim)
num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
global_step = 0
epoch = -1
last_optim_state = None
if args.load_checkpoint:
epoch, global_step, last_optim_state, model = load_model(epoch, args, n_gpu, device, model, global_step=global_step)
epoch += 1
if args.local_rank == 0:
logger.warning("Will continue to epoch: {}".format(epoch))
epoch = 0 if epoch < 0 else epoch
coef_lr = args.coef_lr
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)
if last_optim_state is not None:
optimizer.load_state_dict(last_optim_state)
if args.local_rank == 0:
logger.info("***** Running pretraining *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)
iter_ls_ = [itm for itm in range(args.epochs) if itm >= epoch]
for epoch in iter_ls_:
sampler.set_epoch(epoch)
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
scheduler, global_step, local_rank=args.local_rank)
if args.local_rank == 0:
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
save_model(epoch, args, model, args.local_rank, type_name="pretrain", global_step=global_step, optimizer=optimizer)
if __name__ == "__main__":
main()

View file

@ -4,62 +4,40 @@ from __future__ import unicode_literals
from __future__ import print_function
import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)
from data_prefetch_unitl import DataLoaderX as DataLoader # Enhanced Loader
from torch.utils.data import (SequentialSampler)
import numpy as np
import random
import os
from collections import OrderedDict
from youcook_transcript_nopair_dataloader import Youcook_Transcript_NoPair_DataLoader
from youtube_transcript_dataloader import Youtube_Transcript_DataLoader
from rouge import Rouge
from nlgeval import compute_metrics, NLGEval
from nlgeval import NLGEval
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from pytorch_pretrained_bert.beam import Beam
from modules.tokenization import BertTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import UniVL
from modules.optimization import BertAdam
from modules.beam import Beam
from torch.utils.data import DataLoader
from dataloader_youcook_caption import Youcook_Caption_DataLoader
from util import get_logger
torch.distributed.init_process_group(backend="nccl")
global amp_pck_loaded_
try:
from apex import amp
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
amp.register_half_function(torch, 'einsum')
amp_pck_loaded_ = True
except ImportError:
amp_pck_loaded_ = False
def get_logger(filename=None):
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
if filename is not None:
handler = logging.FileHandler(filename)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logging.getLogger().addHandler(handler)
return logger
global logger
def get_args(description='Youtube-Text-Video'):
def get_args(description='UniVL on Caption Task'):
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--train_csv', type=str, default='data/HowTo100M_v1.csv', help='train csv')
parser.add_argument('--features_path_2D', type=str, default='feature_2d', help='feature path for 2D features')
parser.add_argument('--features_path_3D', type=str, default='feature_3d', help='feature path for 3D features')
parser.add_argument('--caption_path', type=str, default='data/caption.pickle', help='caption pickle file path')
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument('--train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
parser.add_argument('--val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
parser.add_argument('--data_path', type=str, default='data/youcookii_caption_transcript.pickle',
help='caption and transcription pickle file path')
parser.add_argument('--features_path', type=str, default='data/youcookii_videos_feature.pickle',
help='feature path for 2D features')
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
@ -72,7 +50,6 @@ def get_args(description='Youtube-Text-Video'):
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--max_words', type=int, default=20, help='')
parser.add_argument('--max_frames', type=int, default=100, help='')
parser.add_argument('--min_words', type=int, default=0, help='')
parser.add_argument('--feature_framerate', type=int, default=1, help='')
parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
@ -80,34 +57,16 @@ def get_args(description='Youtube-Text-Video'):
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption_transcript.pickle', help='youcookii caption and transcription pickle file path')
parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii feature path for 2D features')
parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii feature path for 3D features')
parser.add_argument('--pad_token', type=str, default='[PAD]', help='')
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True, help="Bert pre-trained model")
parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module")
parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module")
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.");
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--n_gpu', type=int, default=1, help="Set automatically at runtime.")
@ -121,52 +80,32 @@ def get_args(description='Youtube-Text-Video'):
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--task_type", default="caption", type=str, help="Point the task `retrieval` or `caption` to finetune.")
parser.add_argument("--task_type", default="caption", type=str, help="Point the task `caption` to finetune.")
parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")
parser.add_argument("--world_size", default=0, type=int, help="distribted training")
parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use sampled MIL; takes priority over use_mil.")
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Layer NO. of visual.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Layer NO. of visual.")
parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Layer NO. of cross.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Layer NO. of decoder.")
parser.add_argument('--cross_model', action='store_true', help="Whether to train with the decoder.")
parser.add_argument('--pretrain_with_joint_sim', action='store_true', help="Whether to use the joint embedding when pretraining.")
parser.add_argument('--pretrain_enhance_vmodal', action='store_true', help="Enhance the visual and other modalities when pretraining.")
parser.add_argument('--without_sim_in_decoder', action='store_true', help="Whether to skip the alignment (similarity) loss in the decoder when training.")
parser.add_argument('--pretrain_without_decoder', action='store_true', help="Whether to ignore the decoder when pretraining.")
parser.add_argument("--load_checkpoint", action="store_true")
parser.add_argument("--checkpoint_model", default="pytorch_model.bin.checkpoint", type=str, required=False,
help="Save the last model as a checkpoint.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Layer NO. of decoder.")
parser.add_argument('--stage_two', action='store_true', help="Whether to train with the decoder (stage two).")
args = parser.parse_args()
if args.sampled_use_mil: # sample from each video; takes priority over use_mil.
args.use_mil = True
if args.do_pretrain is False:
args.pretrain_without_decoder = False
global amp_pck_loaded_
if args.fp16:
if not amp_pck_loaded_:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
# Check parameters
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if not args.do_pretrain and not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_pretrain` or `do_train` or `do_eval` must be True.")
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
args.checkpoint_model = '{}_{}_{}_{}.checkpoint'.format(args.checkpoint_model, args.bert_model, args.max_words, args.max_frames)
return args
def set_seed_logger(args):
@ -180,24 +119,22 @@ def set_seed_logger(args):
torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# torch.multiprocessing.set_sharing_strategy('file_system') # RuntimeError: received 0 items of ancdata.
world_size = torch.distributed.get_world_size()
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
args.local_rank = local_rank
torch.cuda.set_device(args.local_rank)
args.world_size = world_size
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
if local_rank == 0:
if args.local_rank == 0:
logger.info("Effective parameters:")
for key in sorted(args.__dict__):
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
return args, world_size, local_rank
return args
def init_device(args, local_rank):
global logger
@ -223,7 +160,7 @@ def init_model(args, device, n_gpu, local_rank):
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
@ -259,33 +196,19 @@ def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, loc
schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
max_grad_norm=1.0)
if args.fp16:
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if n_gpu > 1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
output_device=local_rank, find_unused_parameters=True)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
output_device=local_rank, find_unused_parameters=True)
return optimizer, scheduler, model
def dataloader_youcook_train(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
youcook_dataset = Youcook_Transcript_NoPair_DataLoader(
csv=args.youcook_train_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
youcook_dataset = Youcook_Caption_DataLoader(
csv=args.train_csv,
data_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=-1,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
@ -303,22 +226,13 @@ def dataloader_youcook_train(args, tokenizer):
return dataloader, len(youcook_dataset), train_sampler
def dataloader_youcook_test(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
youcook_testset = Youcook_Transcript_NoPair_DataLoader(
csv=args.youcook_val_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
youcook_testset = Youcook_Caption_DataLoader(
csv=args.val_csv,
data_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=-1,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
@ -331,47 +245,10 @@ def dataloader_youcook_test(args, tokenizer):
pin_memory=False,
)
logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
if args.local_rank == 0:
logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
return dataloader_youcook, len(youcook_testset)
def dataloader_pretrain(args, tokenizer, only_sim=False):
logger.info('Loading captions: {}'.format(args.caption_path))
caption = pickle.load(open(args.caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
dataset = Youtube_Transcript_DataLoader(
csv=args.train_csv,
features_path=args.features_path_2D,
features_path_3D=args.features_path_3D,
caption=caption,
min_time=args.min_time,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=args.n_pair,
pad_token=args.pad_token,
max_frames=args.max_frames,
use_mil=args.use_mil,
only_sim=only_sim,
sampled_use_mil=args.sampled_use_mil,
pretrain_enhance_vmodal=args.pretrain_enhance_vmodal,
video_dim=args.video_dim,
)
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = DataLoader(
dataset,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
pin_memory=False,
shuffle=(sampler is None),
sampler=sampler,
drop_last=True,
)
return dataloader, len(dataset), sampler
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
@ -385,63 +262,31 @@ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
else:
return state_dict
def save_model(epoch, args, model, local_rank, type_name="", global_step=-1, optimizer=None):
def save_model(epoch, args, model, type_name=""):
# Only save the model itself
model_to_save = model.module if hasattr(model, 'module') else model
output_model_file = os.path.join(
args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch))
torch.save(model_to_save.state_dict(), output_model_file)
logger.info("Model saved to %s", output_model_file)
if global_step != -1 and optimizer is not None:
amp_state_dict = {}
if args.fp16:
amp_state_dict = amp.state_dict()
state_dict = {
'epoch': epoch,
'global_step': global_step,
'model_state_dict': model_to_save.state_dict(),
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'amp_state_dict': amp_state_dict,
}
checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
torch.save(state_dict, checkpoint_model_file)
logger.info("Checkpoint is saved. use `load_checkpoint` to recovery it.")
return output_model_file
def load_model(epoch, args, n_gpu, device, model, global_step=0, model_file=None):
def load_model(epoch, args, n_gpu, device, model_file=None):
if model_file is None or len(model_file) == 0:
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
last_optim_state = None
amp_state_dict = None
checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
if epoch == -1 and args.load_checkpoint and os.path.exists(checkpoint_model_file):
checkpoint_state = torch.load(checkpoint_model_file, map_location='cpu')
epoch = checkpoint_state['epoch']
global_step = checkpoint_state['global_step']
model_state_dict = checkpoint_state['model_state_dict']
last_optim_state = checkpoint_state['last_optimizer_state']
if args.fp16 and 'amp_state_dict' in checkpoint_state:
amp_state_dict = checkpoint_state['amp_state_dict']
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
logger.info("Checkpoint loaded from %s", checkpoint_model_file)
elif os.path.exists(model_file):
if os.path.exists(model_file):
model_state_dict = torch.load(model_file, map_location='cpu')
logger.info("Model loaded from %s", model_file)
if args.local_rank == 0:
logger.info("Model loaded from %s", model_file)
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
return epoch, global_step, last_optim_state, amp_state_dict, model
else:
model = None
return model
def train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer, scheduler,
global_step, nlgEvalObj=None, local_rank=0):
@ -473,19 +318,12 @@ def train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu,
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
loss.backward()
total_loss += float(loss)
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if scheduler is not None:
scheduler.step() # Update learning rate schedule
@ -602,12 +440,12 @@ def collect_hypothesis_and_scores(inst_dec_beams, n_best):
return all_hyp, all_scores
# >----------------------------------------
def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=None, nlgEvalObj=None, test_set=None):
def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=None, test_set=None):
if hasattr(model, 'module'):
model = model.module.to(device)
if model._choice_sim:
if model._stage_one:
return 0.
all_result_lists = []
@ -628,12 +466,12 @@ def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=
n_inst, len_s, d_h = sequence_output.size()
_, len_v, v_h = visual_output.size()
decoder = model.decoder_caption # This is a decoder function
decoder = model.decoder_caption
# Note: the inputs are reshaped first, so the decoder must be called with shaped=True
input_ids = input_ids.view(-1, input_ids.shape[-1]) # [batch_size*n_pair, self.max_words]
input_mask = input_mask.view(-1, input_mask.shape[-1]) # [batch_size*n_pair, self.max_words]
video_mask = video_mask.view(-1, video_mask.shape[-1]) # [batch_size*n_pair, self.max_frames]
input_ids = input_ids.view(-1, input_ids.shape[-1])
input_mask = input_mask.view(-1, input_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
sequence_output_rpt = sequence_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
visual_output_rpt = visual_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_v, v_h)
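# The encoder outputs are tiled n_bm times along the batch axis so each beam hypothesis decodes
# against its own copy of the text/video context during beam search.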
@ -696,7 +534,7 @@ def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=
writer.write("{}\t{}\t{}\n".format("video_id", "start_time", "caption"))
for idx, pre_txt in enumerate(all_result_lists):
video_id, sub_id = test_set.iter2video_pairs_dict[idx]
start_time = test_set.caption[video_id]['start'][sub_id]
start_time = test_set.data_dict[video_id]['start'][sub_id]
writer.write("{}\t{}\t{}\n".format(video_id, start_time, pre_txt))
logger.info("File of complete results is saved in {}".format(hyp_path))
@ -711,17 +549,7 @@ def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=
for ground_txt in all_caption_lists:
writer.write(ground_txt + "\n")
# Filter out hyps of 0 length
hyps_and_refs = zip(all_result_lists, all_caption_lists)
hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0
and len([" ".join(sub_s.split()) for sub_s in _[0].split(".") if len(sub_s) > 0]) > 0]
all_result_lists, all_caption_lists = zip(*hyps_and_refs)
# Evaluate
metrics_dict = rougeObj.get_scores(hyps=all_result_lists, refs=all_caption_lists, avg=True, ignore_empty=True)
logger.info(">>> rouge_1f: {:.4f}, rouge_2f: {:.4f}, rouge_lf: {:.4f}".
format(metrics_dict["rouge-1"]["f"], metrics_dict["rouge-2"]["f"], metrics_dict["rouge-l"]["f"]))
metrics_nlg = nlgEvalObj.compute_metrics(ref_list=[all_caption_lists], hyp_list=all_result_lists)
logger.info(">>> BLEU_1: {:.4f}, BLEU_2: {:.4f}, BLEU_3: {:.4f}, BLEU_4: {:.4f}".
format(metrics_nlg["Bleu_1"], metrics_nlg["Bleu_2"], metrics_nlg["Bleu_3"], metrics_nlg["Bleu_4"]))
@ -736,83 +564,34 @@ DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader
def main():
global logger
args = get_args()
args, world_size, local_rank = set_seed_logger(args)
device, n_gpu = init_device(args, local_rank)
args = set_seed_logger(args)
device, n_gpu = init_device(args, args.local_rank)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
model = init_model(args, device, n_gpu, args.local_rank)
assert args.task_type == "caption"
nlgEvalObj = NLGEval(no_overlap=False, no_skipthoughts=True, no_glove=True, metrics_to_omit=None)
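# With the skip-thoughts and GloVe-based scorers disabled, nlg-eval reports only the overlap metrics
# (BLEU-1..4, METEOR, ROUGE-L, CIDEr), which is all the caption evaluation below uses.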
assert args.datatype in DATALOADER_DICT
test_dataloader, test_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer)
if args.local_rank == 0:
logger.info("***** Running test *****")
logger.info(" Num examples = %d", test_length)
logger.info(" Batch size = %d", args.batch_size_val)
logger.info(" Num steps = %d", len(test_dataloader))
if args.do_train:
args.task_type = "caption"
model = init_model(args, device, n_gpu, local_rank)
only_sim = model.module._choice_sim if hasattr(model, 'module') else model._choice_sim
if args.task_type == "caption":
rougeObj = Rouge()
nlgEvalObj = NLGEval(no_overlap=False, no_skipthoughts=True, no_glove=True, metrics_to_omit=None)
datatype = "youcook"
assert datatype in DATALOADER_DICT
if args.do_pretrain is False:
test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
if local_rank == 0:
logger.info("***** Running test *****")
logger.info(" Num examples = %d", test_length)
logger.info(" Batch size = %d", args.batch_size_val)
logger.info(" Num steps = %d", len(test_dataloader))
if args.do_pretrain:
train_dataloader, train_length, sampler = dataloader_pretrain(args, tokenizer, only_sim=only_sim)
num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
global_step = 0
epoch = -1
last_optim_state = None
amp_state_dict = None
if args.load_checkpoint:
epoch, global_step, last_optim_state, amp_state_dict, model = load_model(epoch, args, n_gpu, device, model, global_step=global_step)
epoch += 1
if local_rank == 0:
logger.warning("Will continue to epoch: {}".format(epoch))
epoch = 0 if epoch < 0 else epoch
coef_lr = args.coef_lr
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)
if last_optim_state is not None:
optimizer.load_state_dict(last_optim_state)
if amp_state_dict is not None:
amp.load_state_dict(amp_state_dict)
if local_rank == 0:
logger.info("***** Running pretraining *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)
iter_ls_ = [itm for itm in range(args.epochs) if itm >= epoch]
for epoch in iter_ls_:
sampler.set_epoch(epoch)
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)
if local_rank == 0:
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
save_model(epoch, args, model, local_rank, type_name="pretrain", global_step=global_step, optimizer=optimizer)
elif args.do_train:
train_dataloader, train_length, train_sampler = DATALOADER_DICT[datatype]["train"](args, tokenizer)
train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
coef_lr = args.coef_lr
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)
if local_rank == 0:
if args.local_rank == 0:
logger.info("***** Running training *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
@ -825,13 +604,13 @@ def main():
train_sampler.set_epoch(epoch)
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)
scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=args.local_rank)
if local_rank == 0:
if args.local_rank == 0:
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
output_model_file = save_model(epoch, args, model, local_rank, type_name="")
output_model_file = save_model(epoch, args, model, type_name="")
if epoch > 0:
Bleu_4 = eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
Bleu_4 = eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
if best_score <= Bleu_4:
best_score = Bleu_4
best_output_model_file = output_model_file
@ -839,12 +618,12 @@ def main():
else:
logger.warning("Skip the evaluation after {}-th epoch.".format(epoch+1))
if local_rank == 0:
_, _, _, _, model = load_model(-1, args, n_gpu, device, model, model_file=best_output_model_file)
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
if args.local_rank == 0:
model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
elif args.do_eval:
if local_rank == 0:
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
if args.local_rank == 0:
eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
if __name__ == "__main__":
main()

View file

@ -4,43 +4,39 @@ from __future__ import unicode_literals
from __future__ import print_function
import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)
# from torch.utils.data import DataLoader
from data_prefetch_unitl import DataLoaderX as DataLoader # Enhanced Loader
from torch.utils.data import (SequentialSampler)
import numpy as np
import random
import os
from youcook_nopair_dataloader import Youcook_NoPair_DataLoader
from metrics import compute_metrics, print_computed_metrics
from metrics import compute_metrics
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from util import parallel_apply
def get_logger(filename=None):
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
if filename is not None:
handler = logging.FileHandler(filename)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logging.getLogger().addHandler(handler)
return logger
from modules.tokenization import BertTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import UniVL
from modules.optimization import BertAdam
from torch.utils.data import DataLoader
from util import parallel_apply, get_logger
from dataloader_youcook_retrieval import Youcook_DataLoader
from dataloader_msrvtt_retrieval import MSRVTT_DataLoader
from dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader
torch.distributed.init_process_group(backend="nccl")
global logger
def get_args(description='Youtube-Text-Video'):
def get_args(description='UniVL on Retrieval Task'):
parser = argparse.ArgumentParser(description=description)
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument('--train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
parser.add_argument('--val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
parser.add_argument('--data_path', type=str, default='data/youcookii_caption.pickle', help='data pickle file path')
parser.add_argument('--features_path', type=str, default='data/youcookii_videos_feature.pickle', help='feature path')
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
@ -52,42 +48,23 @@ def get_args(description='Youtube-Text-Video'):
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--max_words', type=int, default=20, help='')
parser.add_argument('--max_frames', type=int, default=100, help='')
parser.add_argument('--min_words', type=int, default=0, help='')
parser.add_argument('--feature_framerate', type=int, default=1, help='')
parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption.pickle', help='youcookii caption pickle file path')
parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii 2D feature path')
parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii 3D feature path')
parser.add_argument('--pad_token', type=str, default='[PAD]', help='')
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True,
help="Bert pre-trained model")
parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module")
parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module")
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.");
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--n_gpu', type=int, default=1, help="Set automatically at runtime.")
@ -103,28 +80,29 @@ def get_args(description='Youtube-Text-Video'):
parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")
parser.add_argument("--world_size", default=0, type=int, help="distribted training")
parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use sampled MIL; takes priority over use_mil.")
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Layer NO. of visual.")
parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Layer NO. of visual.")
parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Layer NO. of cross.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Layer NO. of decoder.")
parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Layer NO. of decoder.")
parser.add_argument('--train_sim_after_cross', action='store_true', help="Test retrieval after cross encoder.")
parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="")
args = parser.parse_args()
if args.sampled_use_mil: # sample from each video; takes priority over use_mil.
args.use_mil = True
# Check parameters
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if not args.do_pretrain and not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_pretrain` or `do_train` or `do_eval` must be True.")
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
@ -142,17 +120,27 @@ def set_seed_logger(args):
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(args.local_rank)
args.world_size = world_size
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
logger.info("Effective parameters:")
for key in sorted(args.__dict__):
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
def init_device(args):
if args.local_rank == 0:
logger.info("Effective parameters:")
for key in sorted(args.__dict__):
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
return args
def init_device(args, local_rank):
global logger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)
n_gpu = torch.cuda.device_count()
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
args.n_gpu = n_gpu
@ -163,7 +151,7 @@ def init_device(args):
return device, n_gpu
def init_model(args, device, n_gpu):
def init_model(args, device, n_gpu, local_rank):
if args.init_model:
model_state_dict = torch.load(args.init_model, map_location='cpu')
@ -172,24 +160,14 @@ def init_model(args, device, n_gpu):
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
if args.fp16:
try:
import apex
apex.amp.register_half_function(torch, 'einsum')
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
return model
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=1.0):
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
if hasattr(model, 'module'):
model = model.module
@ -212,72 +190,49 @@ def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coe
{'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
{'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
]
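# The pretrained BERT text branch uses lr * coef_lr (default 0.1) so it is updated more gently than the
# freshly initialized visual/cross/decoder modules; when --init_model provides full pretrained weights,
# main() resets coef_lr to 1.0.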
scheduler = None
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
max_grad_norm=1.0)
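# BertAdam applies the linear warmup/decay schedule internally via `warmup` and `t_total`, which is why
# the external `scheduler` is left as None in this script.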
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if n_gpu > 1:
model = torch.nn.DataParallel(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
output_device=local_rank, find_unused_parameters=True)
return optimizer, scheduler, model
def dataloader_youcook_train(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
args.n_pair = 1
youcook_dataset = Youcook_NoPair_DataLoader(
csv=args.youcook_train_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
youcook_dataset = Youcook_DataLoader(
csv=args.train_csv,
data_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=args.n_pair,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
train_sampler = torch.utils.data.distributed.DistributedSampler(youcook_dataset)
dataloader = DataLoader(
youcook_dataset,
batch_size=args.batch_size,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
shuffle=True,
pin_memory=False,
shuffle=(train_sampler is None),
sampler=train_sampler,
drop_last=True,
)
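# args.batch_size is the global batch size for one optimization step (after the gradient-accumulation
# division in get_args), so each DDP worker loads batch_size // n_gpu samples and the DistributedSampler
# keeps the per-rank shards disjoint.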
return dataloader, len(youcook_dataset)
return dataloader, len(youcook_dataset), train_sampler
def dataloader_youcook_test(args, tokenizer):
logger.info('Loading captions: {}'.format(args.youcook_caption_path))
caption = pickle.load(open(args.youcook_caption_path, 'rb'))
logger.info('Done, caption length: {}'.format(len(caption)))
youcook_testset = Youcook_NoPair_DataLoader(
csv=args.youcook_val_csv,
features_path=args.youcook_features_path_2D,
features_path_3D=args.youcook_features_path_3D,
caption=caption,
min_time=args.min_time,
youcook_testset = Youcook_DataLoader(
csv=args.val_csv,
data_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
min_words=args.min_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
n_pair=-1,
pad_token=args.pad_token,
max_frames=args.max_frames,
)
@ -293,6 +248,49 @@ def dataloader_youcook_test(args, tokenizer):
return dataloader_youcook, len(youcook_testset)
def dataloader_msrvtt_train(args, tokenizer):
msrvtt_dataset = MSRVTT_TrainDataLoader(
csv_path=args.train_csv,
json_path=args.data_path,
features_path=args.features_path,
max_words=args.max_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
max_frames=args.max_frames,
unfold_sentences=args.expand_msrvtt_sentences,
)
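# unfold_sentences (from --expand_msrvtt_sentences) presumably expands every caption of a video into its
# own text-video training pair rather than sampling one caption per clip; this is an assumption based on
# the flag name, not verified against dataloader_msrvtt_retrieval.py.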
train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset)
dataloader = DataLoader(
msrvtt_dataset,
batch_size=args.batch_size // args.n_gpu,
num_workers=args.num_thread_reader,
pin_memory=False,
shuffle=(train_sampler is None),
sampler=train_sampler,
drop_last=True,
)
return dataloader, len(msrvtt_dataset), train_sampler
def dataloader_msrvtt_test(args, tokenizer):
msrvtt_testset = MSRVTT_DataLoader(
csv_path=args.val_csv,
features_path=args.features_path,
max_words=args.max_words,
feature_framerate=args.feature_framerate,
tokenizer=tokenizer,
max_frames=args.max_frames,
)
dataloader_msrvtt = DataLoader(
msrvtt_testset,
batch_size=args.batch_size_val,
num_workers=args.num_thread_reader,
shuffle=False,
drop_last=False,
)
return dataloader_msrvtt, len(msrvtt_testset)
def save_model(epoch, args, model, type_name=""):
# Only save the model itself
model_to_save = model.module if hasattr(model, 'module') else model
@ -307,10 +305,11 @@ def load_model(epoch, args, n_gpu, device, model_file=None):
model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
if os.path.exists(model_file):
model_state_dict = torch.load(model_file, map_location='cpu')
logger.info("Model loaded from %s", model_file)
if args.local_rank == 0:
logger.info("Model loaded from %s", model_file)
# Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
model.to(device)
@ -318,24 +317,18 @@ def load_model(epoch, args, n_gpu, device, model_file=None):
model = None
return model
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step):
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
global logger
torch.cuda.empty_cache()
model.train()
log_step = 100
log_step = args.n_display
start_time = time.time()
total_loss = 0
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
for step, batch in enumerate(train_dataloader):
if n_gpu == 1:
# multi-gpu does scattering itself
batch = tuple(t.to(device) for t in batch)
batch = tuple(t.to(device=device, non_blocking=True) for t in batch)
input_ids, input_mask, segment_ids, video, video_mask, \
pairs_masked_text, pairs_token_labels, masked_video, video_labels_index = batch
@ -348,19 +341,12 @@ def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
loss.backward()
total_loss += float(loss)
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
if scheduler is not None:
scheduler.step() # Update learning rate schedule
@ -369,10 +355,11 @@ def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
optimizer.zero_grad()
global_step += 1
if global_step % log_step == 0:
if global_step % log_step == 0 and local_rank == 0:
logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
args.epochs, step + 1,
len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]), loss,
len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
float(loss) * args.gradient_accumulation_steps,
(time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
start_time = time.time()
@ -418,7 +405,6 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu):
print("{}/{}\r".format(bid, len(test_dataloader)), end="")
start_time = time.time()
if n_gpu > 1:
device_ids = list(range(n_gpu))
batch_list_t_splits = []
@ -455,7 +441,6 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu):
sim_matrix += parallel_outputs[idx]
else:
sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list)
logger.info("%f" % (time.time() - start_time))
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
metrics = compute_metrics(sim_matrix)
@ -468,62 +453,65 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu):
DATALOADER_DICT = {}
DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader_youcook_test}
DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test}
def main():
global logger
args = get_args()
set_seed_logger(args)
device, n_gpu = init_device(args)
args = set_seed_logger(args)
device, n_gpu = init_device(args, args.local_rank)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
assert args.task_type == "retrieval"
model = init_model(args, device, n_gpu, args.local_rank)
assert args.datatype in DATALOADER_DICT
test_dataloader, test_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer)
if args.local_rank == 0:
logger.info("***** Running test *****")
logger.info(" Num examples = %d", test_length)
logger.info(" Batch size = %d", args.batch_size_val)
logger.info(" Num steps = %d", len(test_dataloader))
if args.do_train:
args.task_type = "retrieval"
model = init_model(args, device, n_gpu)
datatype = args.datatype
assert datatype in DATALOADER_DICT
test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
logger.info("***** Running test *****")
logger.info(" Num examples = %d", test_length)
logger.info(" Batch size = %d", args.batch_size_val)
logger.info(" Num steps = %d", len(test_dataloader))
if args.do_pretrain:
raise NotImplementedError("Deleted~ no use")
elif args.do_train:
train_dataloader, train_length = DATALOADER_DICT[datatype]["train"](args, tokenizer)
train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
/ args.gradient_accumulation_steps) * args.epochs
coef_lr = 0.1
coef_lr = args.coef_lr
if args.init_model:
coef_lr = 1.0
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=coef_lr)
optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
if args.local_rank == 0:
logger.info("***** Running training *****")
logger.info(" Num examples = %d", train_length)
logger.info(" Batch size = %d", args.batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)
best_score = 0.00001
best_output_model_file = None
global_step = 0
for epoch in range(args.epochs):
train_sampler.set_epoch(epoch)
tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
scheduler, global_step)
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
output_model_file = save_model(epoch, args, model, type_name="")
scheduler, global_step, local_rank=args.local_rank)
if args.local_rank == 0:
logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
output_model_file = save_model(epoch, args, model, type_name="")
R1 = eval_epoch(args, model, test_dataloader, device, n_gpu)
if best_score <= R1:
best_score = R1
best_output_model_file = output_model_file
logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))
model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
eval_epoch(args, model, test_dataloader, device, n_gpu)
R1 = eval_epoch(args, model, test_dataloader, device, n_gpu)
if best_score <= R1:
best_score = R1
best_output_model_file = output_model_file
logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))
if args.local_rank == 0:
model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
eval_epoch(args, model, test_dataloader, device, n_gpu)
elif args.do_eval:
eval_epoch(args, model, test_dataloader, device, n_gpu)
if args.local_rank == 0:
eval_epoch(args, model, test_dataloader, device, n_gpu)
if __name__ == "__main__":
main()

View file

@ -19,7 +19,6 @@ def compute_metrics(x):
metrics['MR'] = np.median(ind) + 1
return metrics
def print_computed_metrics(metrics):
r1 = metrics['R1']
r5 = metrics['R5']

View file

View file

@ -1,12 +1,11 @@
""" Manage beam search info structure.
Heavily borrowed from OpenNMT-py.
For code in OpenNMT-py, please check the following link:
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
"""
Manage beam search info structure.
Heavily borrowed from OpenNMT-py.
For code in OpenNMT-py, please check the following link (maybe in oldest version):
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
"""
import torch
import numpy as np
class Constants():
def __init__(self):
@ -14,7 +13,7 @@ class Constants():
self.UNK = 1
self.BOS = 2
self.EOS = 3
self.PAD_WORD = '[MASK]'
self.PAD_WORD = '[PAD]'
self.UNK_WORD = '[UNK]'
self.BOS_WORD = '[CLS]'
self.EOS_WORD = '[SEP]'
@ -76,7 +75,7 @@ class Beam():
self.scores = best_scores
# bestScoresId is flattened as a (beam x word) array,
# so we need to calculate which word and beam each score came from
prev_k = best_scores_id / num_words
prev_k = best_scores_id // num_words
self.prev_ks.append(prev_k)
self.next_ys.append(best_scores_id - prev_k * num_words)
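# Floor division is required under Python 3: `/` would return floats that cannot index tensors.
# Worked example: with num_words = 4 and best_scores_id = [6, 1], prev_k = [1, 0] picks the source
# beams, and best_scores_id - prev_k * num_words = [2, 1] recovers the word ids within those beams.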
# End condition is when top-of-beam is EOS.

View file

View file

@ -19,14 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
@ -34,38 +27,22 @@ from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from pytorch_pretrained_bert.bert_module import PreTrainedBertModel, BertModel, BertLayerNorm, BertOnlyMLMHead
from pytorch_pretrained_bert.visual_module import PreTrainedVisualBertModel, VisualBertModel, VisualBertLayerNorm, VisualBertOnlyMLMHead
from pytorch_pretrained_bert.cross_module import PreTrainedCrossBertModel, CrossBertModel, CrossBertLayerNorm
from pytorch_pretrained_bert.decoder_module import PreTrainedDecoderBertModel, DecoderBertModel, DecoderBertLayerNorm
from modules.until_module import PreTrainedModel, LayerNorm, CrossEn, MILNCELoss, MaxMarginRankingLoss
from modules.module_bert import BertModel, BertConfig, BertOnlyMLMHead
from modules.module_visual import VisualModel, VisualConfig, VisualOnlyMLMHead
from modules.module_cross import CrossModel, CrossConfig
from modules.module_decoder import DecoderModel, DecoderConfig
logger = logging.getLogger(__name__)
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(LayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class PreTrainedModel(nn.Module):
class UniVLPreTrainedModel(PreTrainedModel, nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
# utilize bert config as base config
super(UniVLPreTrainedModel, self).__init__(bert_config)
self.bert_config = bert_config
self.visual_config = visual_config
self.cross_config = cross_config
@ -76,88 +53,8 @@ class PreTrainedModel(nn.Module):
self.cross = None
self.decoder = None
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
elif isinstance(module, BertLayerNorm) \
or isinstance(module, LayerNorm) \
or isinstance(module, VisualBertLayerNorm) \
or isinstance(module, CrossBertLayerNorm) \
or isinstance(module, DecoderBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
raise ValueError("`TODO: Just bert has this function, call it at here.`")
@classmethod
def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
if prefix is not None:
old_keys = []
new_keys = []
for key in state_dict.keys():
old_keys.append(key)
new_keys.append(prefix+key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='')
if prefix is None and (task_config is None or task_config.local_rank == 0):
logger.info("-"*20)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, pretrained_visual_model_name, pretrained_cross_model_name,
pretrained_decoder_model_name,
def from_pretrained(cls, pretrained_bert_name, visual_model_name, cross_model_name, decoder_model_name,
state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
task_config = None
@ -168,94 +65,21 @@ class PreTrainedModel(nn.Module):
elif task_config.local_rank == -1:
task_config.local_rank = 0
bert_config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=task_config)
# visual_state_dict is no use here
visual_config, _ = PreTrainedVisualBertModel.get_config(pretrained_visual_model_name, cache_dir, state_dict=None, task_config=task_config)
# visual_state_dict is no use here
cross_config, _ = PreTrainedCrossBertModel.get_config(pretrained_cross_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
decoder_config, _ = PreTrainedDecoderBertModel.get_config(pretrained_decoder_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
# Instantiate model.
bert_config, state_dict = BertConfig.get_config(pretrained_bert_name, cache_dir, type_vocab_size, state_dict, task_config=task_config)
visual_config, _ = VisualConfig.get_config(visual_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
model = cls(bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs)
assert model.bert is not None
assert model.visual is not None
# assert model.cross is not None
# assert model.decoder is not None
if state_dict is not None:
model = cls.init_preweight(model, state_dict, task_config=task_config)
return model
class CrossEn(nn.Module):
def __init__(self,):
super(CrossEn, self).__init__()
def forward(self, sim_matrix):
logpt = F.log_softmax(sim_matrix, dim=-1)
logpt = torch.diag(logpt)
nce_loss = -logpt
sim_loss = nce_loss.mean()
return sim_loss
class MILNCELoss(nn.Module):
def __init__(self, batch_size=1, n_pair=1,):
super(MILNCELoss, self).__init__()
self.batch_size = batch_size
self.n_pair = n_pair
def forward(self, sim_matrix):
mm_mask = np.eye(self.batch_size)
mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)
from_text_matrix = sim_matrix + mm_mask * -1e12
from_video_matrix = sim_matrix.transpose(1, 0) # video * text
new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
logpt = F.log_softmax(new_sim_matrix, dim=-1)
mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
new_logpt = -torch.logsumexp(masked_logpt, dim=-1)
logpt_choice = torch.zeros_like(new_logpt)
mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
logpt_choice[mark_ind] = 1
sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=torch.uint8)).mean()
return sim_loss
class MaxMarginRankingLoss(nn.Module):
def __init__(self,
margin=1.0,
negative_weighting=False,
batch_size=1,
n_pair=1,
hard_negative_rate=0.5,
):
super(MaxMarginRankingLoss, self).__init__()
self.margin = margin
self.n_pair = n_pair
self.batch_size = batch_size
easy_negative_rate = 1 - hard_negative_rate
self.easy_negative_rate = easy_negative_rate
self.negative_weighting = negative_weighting
if n_pair > 1 and batch_size > 1:
alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
self.mm_mask = mm_mask.float()
def forward(self, x):
d = torch.diag(x)
max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
F.relu(self.margin + x - d.view(1, -1))
if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
max_margin = max_margin * self.mm_mask.to(max_margin.device)
return max_margin.mean()
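
For reference, a minimal standalone sketch (not part of this diff) of how a loss such as `CrossEn` above consumes a square text-video similarity matrix with matched pairs on the diagonal:

```
import torch
import torch.nn.functional as F

# Contrastive loss as in CrossEn: negative log-likelihood of the diagonal
# of the similarity matrix after a row-wise softmax.
def cross_en(sim_matrix):
    logpt = F.log_softmax(sim_matrix, dim=-1)
    return -torch.diag(logpt).mean()

# 4 text-video pairs; entry (i, j) is the similarity between text i and video j,
# so the matched pairs live on the diagonal.
sim_matrix = torch.randn(4, 4) + 5.0 * torch.eye(4)
print(cross_en(sim_matrix).item())   # small when the diagonal dominates
```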
class NormalizeVideo(nn.Module):
def __init__(self, task_config):
super(NormalizeVideo, self).__init__()
@ -279,9 +103,12 @@ def update_attr(target_name, target_config, target_attr_name, source_config, sou
target_attr_name, getattr(target_config, target_attr_name)))
return target_config
class VLBert(PreTrainedModel):
def check_attr(target_name, task_config):
return hasattr(task_config, target_name) and task_config.__dict__[target_name]
class UniVL(UniVLPreTrainedModel):
def __init__(self, bert_config, visual_config, cross_config, decoder_config, task_config):
super(VLBert, self).__init__(bert_config, visual_config, cross_config, decoder_config)
super(UniVL, self).__init__(bert_config, visual_config, cross_config, decoder_config)
self.task_config = task_config
self.ignore_video_index = -1
@ -290,57 +117,54 @@ class VLBert(PreTrainedModel):
assert self.task_config.max_frames <= visual_config.max_position_embeddings
assert self.task_config.max_words + self.task_config.max_frames <= cross_config.max_position_embeddings
self._choice_sim = True
self._cross_model = False
self._stage_one = True
self._stage_two = False
if hasattr(self.task_config, 'cross_model') and self.task_config.cross_model:
self._choice_sim = False
self._cross_model = self.task_config.cross_model
if check_attr('stage_two', self.task_config):
self._stage_one = False
self._stage_two = self.task_config.stage_two
show_log(task_config, "Stage-One:{}, Stage-Two:{}".format(self._stage_one, self._stage_two))
self.train_sim_after_cross = False
if self._choice_sim and hasattr(self.task_config, 'train_sim_after_cross') and self.task_config.train_sim_after_cross:
if self._stage_one and check_attr('train_sim_after_cross', self.task_config):
self.train_sim_after_cross = True
show_log(task_config, "Test retrieval after cross encoder.")
self.with_sim_in_decoder = True
if hasattr(self.task_config, 'without_sim_in_decoder') and self.task_config.without_sim_in_decoder:
self.with_sim_in_decoder = False
show_log(task_config, "Set no sim after cross.")
self.pretrain_without_decoder = False
if hasattr(self.task_config, 'pretrain_without_decoder') and self.task_config.pretrain_without_decoder:
self.pretrain_without_decoder = True
show_log(task_config, "Set pretrain without decoder.")
show_log(task_config, "sim:{}, cross_model:{}".format(self._choice_sim, self._cross_model))
# The name should be `bert`, if dynamic vocabulary is needed
# Text Encoder ===>
bert_config = update_attr("bert_config", bert_config, "num_hidden_layers",
self.task_config, "text_num_hidden_layers")
self.bert = BertModel(bert_config)
bert_word_embeddings_weight = self.bert.embeddings.word_embeddings.weight
bert_position_embeddings_weight = self.bert.embeddings.position_embeddings.weight
# <=== End of Text Encoder
# Video Encoder ===>
visual_config = update_attr("visual_config", visual_config, "num_hidden_layers",
self.task_config, "visual_num_hidden_layers")
self.visual = VisualBertModel(visual_config)
self.visual = VisualModel(visual_config)
visual_word_embeddings_weight = self.visual.embeddings.word_embeddings.weight
# <=== End of Video Encoder
if self._choice_sim is False or self.train_sim_after_cross:
self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
self.cls_visual = VisualBertOnlyMLMHead(visual_config, visual_word_embeddings_weight)
if self._stage_one is False or self.train_sim_after_cross:
# Cross Encoder ===>
cross_config = update_attr("cross_config", cross_config, "num_hidden_layers",
self.task_config, "cross_num_hidden_layers")
self.cross = CrossBertModel(cross_config)
self.cross = CrossModel(cross_config)
# <=== End of Cross Encoder
if self.train_sim_after_cross is False:
if self.task_config.do_pretrain and self.pretrain_without_decoder:
show_log(task_config, "Pretrain without decoder.")
else:
decoder_config = update_attr("decoder_config", decoder_config, "num_decoder_layers",
self.task_config, "decoder_num_hidden_layers")
self.decoder = DecoderBertModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
# Decoder ===>
decoder_config = update_attr("decoder_config", decoder_config, "num_decoder_layers",
self.task_config, "decoder_num_hidden_layers")
self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
# <=== End of Decoder
self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
self.cls_visual = VisualOnlyMLMHead(visual_config, visual_word_embeddings_weight)
self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)
self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.normalize_video = NormalizeVideo(task_config)
@ -350,36 +174,15 @@ class VLBert(PreTrainedModel):
batch_size=task_config.batch_size // task_config.n_gpu,
n_pair=task_config.n_pair,
hard_negative_rate=task_config.hard_negative_rate, )
if task_config.use_mil:
self.loss_fct = mILNCELoss
show_log(task_config, "Using MILNCE.")
self.loss_fct = CrossEn() if self._stage_two else mILNCELoss
self._pretrain_sim_loss_fct = mILNCELoss
else:
self.loss_fct = maxMarginRankingLoss
show_log(task_config, "Using Ranking Loss.")
self.loss_fct = CrossEn() if self._stage_two else maxMarginRankingLoss
self._pretrain_sim_loss_fct = maxMarginRankingLoss
if self._cross_model:
self.loss_fct = CrossEn()
show_log(task_config, "Use CE Loss.")
if self._choice_sim is False or self.train_sim_after_cross:
self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)
self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)
self.pretrain_with_joint_sim = False
if hasattr(self.task_config, 'pretrain_with_joint_sim'):
if self.task_config.pretrain_with_joint_sim:
self.pretrain_with_joint_sim = True
show_log(task_config, "Set pretrain with joint embedding.")
if task_config.use_mil:
self._pretrain_sim_loss_fct = mILNCELoss
show_log(task_config, "Pretrain with joint embedding using MILNCE.")
else:
self._pretrain_sim_loss_fct = maxMarginRankingLoss
show_log(task_config, "Pretrain with joint embedding using Ranking Loss.")
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=None,
pairs_masked_text=None, pairs_token_labels=None, masked_video=None, video_labels_index=None,
@ -395,17 +198,18 @@ class VLBert(PreTrainedModel):
input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])
sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids,
attention_mask, video, video_mask, shaped=True)
sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids, attention_mask,
video, video_mask, shaped=True)
if self.training:
loss = 0.
if self._choice_sim:
sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask, shaped=True)
if self._stage_one:
sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask,
video_mask, shaped=True)
sim_loss = self.loss_fct(sim_matrix)
loss = loss + sim_loss
loss += sim_loss
if self._cross_model and pairs_masked_text is not None and pairs_token_labels is not None:
if self._stage_two and pairs_masked_text is not None and pairs_token_labels is not None:
if self.task_config.do_pretrain:
pairs_masked_text = pairs_masked_text.view(-1, pairs_masked_text.shape[-1])
@ -414,25 +218,25 @@ class VLBert(PreTrainedModel):
masked_video = self.normalize_video(masked_video)
video_labels_index = video_labels_index.view(-1, video_labels_index.shape[-1])
sequence_output_alm, visual_output_alm = self.get_sequence_visual_output(pairs_masked_text, token_type_ids, attention_mask, masked_video, video_mask, shaped=True)
sequence_output_alm, visual_output_alm = self.get_sequence_visual_output(pairs_masked_text, token_type_ids,
attention_mask, masked_video, video_mask, shaped=True)
cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output_alm, visual_output_alm, attention_mask, video_mask)
sequence_cross_output, visual_cross_output = torch.split(cross_output, [attention_mask.size(-1), video_mask.size(-1)], dim=1)
alm_loss = self._calculate_mlm_loss(sequence_cross_output, pairs_token_labels) # For mlm
loss = loss + alm_loss
alm_loss = self._calculate_mlm_loss(sequence_cross_output, pairs_token_labels)
loss += alm_loss
nce_loss = self._calculate_mfm_loss(visual_cross_output, video, video_mask, video_labels_index) # For mfm
nce_loss = self._calculate_mfm_loss(visual_cross_output, video, video_mask, video_labels_index)
loss += nce_loss
if self.pretrain_with_joint_sim:
sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask,
shaped=True, _pretrain_joint=True)
sim_loss_joint = self._pretrain_sim_loss_fct(sim_matrix)
loss = loss + sim_loss_joint
sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask,
shaped=True, _pretrain_joint=True)
sim_loss_joint = self._pretrain_sim_loss_fct(sim_matrix)
loss += sim_loss_joint
if (input_caption_ids is not None) and \
((self.task_config.do_pretrain and self.pretrain_without_decoder is False)
(self.task_config.do_pretrain
or (self.task_config.do_pretrain is False and self.task_config.task_type == "caption")):
if self.task_config.do_pretrain:
decoder_scores, res_tuples = self._get_decoder_score(sequence_output_alm, visual_output_alm,
@ -442,21 +246,25 @@ class VLBert(PreTrainedModel):
decoder_scores, res_tuples = self._get_decoder_score(sequence_output, visual_output,
input_ids, attention_mask, video_mask,
input_caption_ids, decoder_mask, shaped=True)
else:
raise NotImplementedError
output_caption_ids = output_caption_ids.view(-1, output_caption_ids.shape[-1])
decoder_loss = self.decoder_loss_fct(decoder_scores.view(-1, self.bert_config.vocab_size), output_caption_ids.view(-1))
loss = loss + decoder_loss
loss += decoder_loss
if (self.with_sim_in_decoder and self.task_config.do_pretrain) or \
self.task_config.task_type == "retrieval":
if self.task_config.do_pretrain or self.task_config.task_type == "retrieval":
if self.task_config.do_pretrain:
sim_matrix_text_visual = self.get_similarity_logits(sequence_output_alm, visual_output_alm, attention_mask,
video_mask, shaped=True)
sim_matrix_text_visual = self.get_similarity_logits(sequence_output_alm, visual_output_alm,
attention_mask, video_mask, shaped=True)
elif self.task_config.task_type == "retrieval":
sim_matrix_text_visual = self.get_similarity_logits(sequence_output, visual_output, attention_mask,
video_mask, shaped=True)
sim_matrix_text_visual = self.get_similarity_logits(sequence_output, visual_output,
attention_mask, video_mask, shaped=True)
else:
raise NotImplementedError
sim_loss_text_visual = self.loss_fct(sim_matrix_text_visual)
loss = loss + sim_loss_text_visual
loss += sim_loss_text_visual
return loss
else:
@ -488,7 +296,6 @@ class VLBert(PreTrainedModel):
nce_loss = nce_loss.mean()
return nce_loss
# For inference
def get_sequence_visual_output(self, input_ids, token_type_ids, attention_mask, video, video_mask, shaped=False):
if shaped is False:
input_ids = input_ids.view(-1, input_ids.shape[-1])
@ -497,17 +304,26 @@ class VLBert(PreTrainedModel):
video_mask = video_mask.view(-1, video_mask.shape[-1])
video = self.normalize_video(video)
encoded_layers, _ = self.bert("bi", input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
sequence_output = encoded_layers[-1]
visual_layers, _ = self.visual("bi", video, video_mask, output_all_encoded_layers=True)
visual_layers, _ = self.visual(video, video_mask, output_all_encoded_layers=True)
visual_output = visual_layers[-1]
# cannot add any cross-modal code here
# because the sequence and visual features are still encoded separately
return sequence_output, visual_output
def _get_cross_output(self, sequence_output, visual_output, attention_mask, video_mask):
concat_features = torch.cat((sequence_output, visual_output), dim=1)  # concatenate tokens and frames
concat_mask = torch.cat((attention_mask, video_mask), dim=1)
text_type_ = torch.zeros_like(attention_mask)
video_type_ = torch.ones_like(video_mask)
concat_type = torch.cat((text_type_, video_type_), dim=1)
cross_layers, pooled_output = self.cross(concat_features, concat_type, concat_mask, output_all_encoded_layers=True)
cross_output = cross_layers[-1]
return cross_output, pooled_output, concat_mask
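
To make the cross-encoder input above concrete, a small standalone sketch with made-up shapes showing how the joint feature, mask, and segment-type tensors line up:

```
import torch

# Toy shapes for illustration only.
batch, n_tokens, n_frames, hidden = 2, 4, 3, 8
sequence_output = torch.randn(batch, n_tokens, hidden)        # text encoder output
visual_output = torch.randn(batch, n_frames, hidden)          # video encoder output
attention_mask = torch.ones(batch, n_tokens, dtype=torch.long)
video_mask = torch.ones(batch, n_frames, dtype=torch.long)

concat_features = torch.cat((sequence_output, visual_output), dim=1)   # (2, 7, 8)
concat_mask = torch.cat((attention_mask, video_mask), dim=1)           # (2, 7)
concat_type = torch.cat((torch.zeros_like(attention_mask),             # 0 -> text segment
                         torch.ones_like(video_mask)), dim=1)          # 1 -> video segment
print(concat_features.shape, concat_mask.shape, concat_type.shape)
```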
def _mean_pooling_for_similarity(self, sequence_output, visual_output, attention_mask, video_mask,):
attention_mask_un = attention_mask.to(dtype=torch.float).unsqueeze(-1)
attention_mask_un[:, 0, :] = 0.
@ -558,13 +374,12 @@ class VLBert(PreTrainedModel):
retrieve_logits = torch.cat(retrieve_logits_list, dim=0)
return retrieve_logits
# For inference
def get_similarity_logits(self, sequence_output, visual_output, attention_mask, video_mask, shaped=False, _pretrain_joint=False):
if shaped is False:
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
video_mask = video_mask.view(-1, video_mask.shape[-1])
if self._cross_model and _pretrain_joint is False:
if self._stage_two and _pretrain_joint is False:
retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
else:
if self.train_sim_after_cross:
@ -579,19 +394,6 @@ class VLBert(PreTrainedModel):
return retrieve_logits
def _get_cross_output(self, sequence_output, visual_output, attention_mask, video_mask):
# Generate one pair (x, y) ===>
concat_features = torch.cat((sequence_output, visual_output), dim=1)  # concatenate tokens and frames
concat_mask = torch.cat((attention_mask, video_mask), dim=1)
text_type_ = torch.zeros_like(attention_mask)
video_type_ = torch.ones_like(video_mask)
concat_type = torch.cat((text_type_, video_type_), dim=1)
# <===
cross_layers, pooled_output = self.cross("bi", concat_features, concat_type, concat_mask, output_all_encoded_layers=True)
cross_output = cross_layers[-1]
return cross_output, pooled_output, concat_mask
def _get_decoder_score(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=False):
if shaped is False:
@ -604,16 +406,12 @@ class VLBert(PreTrainedModel):
res_tuples = ()
cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output, visual_output, attention_mask, video_mask)
decoder_scores = self.decoder(input_caption_ids, input_ids, encoder_outs=cross_output, answer_mask=decoder_mask, encoder_mask=concat_mask)
decoder_scores = self.decoder(input_caption_ids, encoder_outs=cross_output, answer_mask=decoder_mask, encoder_mask=concat_mask)
return decoder_scores, res_tuples
def decoder_caption(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask,
shaped=False, get_logits=False):
"""
:param get_logits: Used in Beam Search
:return:
"""
if shaped is False:
input_ids = input_ids.view(-1, input_ids.shape[-1])
attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
@ -622,9 +420,9 @@ class VLBert(PreTrainedModel):
input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])
assert self.pretrain_without_decoder is False
decoder_scores, _ = self._get_decoder_score(sequence_output, visual_output,
input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=True)
input_ids, attention_mask, video_mask,
input_caption_ids, decoder_mask, shaped=True)
if get_logits:
return decoder_scores

View file

@ -27,14 +27,13 @@ import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
logger = logging.getLogger(__name__)
@ -51,24 +50,14 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
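
The erf-based gelu removed here and the tanh approximation quoted in its docstring agree closely; a quick standalone numeric check:

```
import math
import torch

def gelu_erf(x):   # exact form used above
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):  # OpenAI GPT-style approximation from the docstring
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3, 3, 7)
print((gelu_erf(x) - gelu_tanh(x)).abs().max())   # roughly 1e-3 or below
```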
class BertConfig(object):
class BertConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `BertModel`.
"""
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
config_name = CONFIG_NAME
weights_name = WEIGHTS_NAME
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
@ -126,52 +115,6 @@ class BertConfig(object):
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
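
As a sanity check (standalone sketch, not part of this diff), the TF-style layer norm above matches `torch.nn.LayerNorm` at default initialization:

```
import torch

hidden = 8
x = torch.randn(2, 4, hidden)
ref = torch.nn.LayerNorm(hidden, eps=1e-12)    # default init: weight=1, bias=0

u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
manual = (x - u) / torch.sqrt(s + 1e-12)
print(torch.allclose(manual, ref(x), atol=1e-6))   # True
```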
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
@ -183,7 +126,7 @@ class BertEmbeddings(nn.Module):
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
@ -258,7 +201,7 @@ class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@ -297,7 +240,7 @@ class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@ -359,7 +302,7 @@ class BertPredictionHeadTransform(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -418,212 +361,7 @@ class BertPreTrainingHeads(nn.Module):
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
model = self.module if hasattr(self, 'module') else self
base_model = getattr(model, "bert", model) # get the base model if needed
old_embeddings = base_model.embeddings.word_embeddings
# get_resized_embeddings <-----
if new_num_tokens is None:
return old_embeddings
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
if old_num_tokens == new_num_tokens:
return old_embeddings
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
device = old_embeddings.weight.device
new_embeddings.to(device)
# initialize all new embeddings (in particular added tokens)
self.init_bert_weights(new_embeddings)
# Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
# ----->
# Update base model and current model config
base_model.embeddings.word_embeddings = new_embeddings
base_model.config.vocab_size = new_num_tokens
model.config.vocab_size = new_num_tokens
model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
'''
Factored out of `from_pretrained`.
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
config.type_vocab_size = type_vocab_size
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_classes for BertForSequenceClassification)
"""
config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedBertModel.init_preweight(model, state_dict)
return model
class BertModel(PreTrainedBertModel):
class BertModel(PreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
@ -673,12 +411,10 @@ class BertModel(PreTrainedBertModel):
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
self.apply(self.init_weights)
def forward(self, type, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
attention_sentenceA_mask=None,
attention_sentenceB_mask=None):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
@ -697,7 +433,7 @@ class BertModel(PreTrainedBertModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
@ -708,4 +444,4 @@ class BertModel(PreTrainedBertModel):
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
return encoded_layers, pooled_output
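
The -10000.0 additive mask described in the comments above effectively removes padded positions from the softmax; a tiny standalone check:

```
import torch

attention_mask = torch.tensor([[1, 1, 0]])             # last position is padding
extended = attention_mask[:, None, None, :].float()    # broadcast over heads / query positions
extended = (1.0 - extended) * -10000.0

raw_scores = torch.tensor([[[[2.0, 1.0, 3.0]]]])       # unnormalized attention scores
probs = torch.softmax(raw_scores + extended, dim=-1)
print(probs)   # ~[[0.731, 0.269, 0.000]]: the padded position gets no attention
```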

View file

@ -27,43 +27,27 @@ import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'cross-bert-base': "[TODO]",
'cross-bert-large': "[TODO]",
}
CONFIG_NAME = 'cross_bert_config.json'
PRETRAINED_MODEL_ARCHIVE_MAP = {}
CONFIG_NAME = 'cross_config.json'
WEIGHTS_NAME = 'cross_pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class CrossBertConfig(object):
"""Configuration class to store the configuration of a `CrossBertModel`.
class CrossConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `CrossModel`.
"""
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
config_name = CONFIG_NAME
weights_name = WEIGHTS_NAME
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
@ -76,10 +60,10 @@ class CrossBertConfig(object):
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
"""Constructs CrossBertConfig.
"""Constructs CrossConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossBertModel`.
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
@ -96,7 +80,7 @@ class CrossBertConfig(object):
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`CrossBertModel`.
`CrossModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
@ -121,64 +105,19 @@ class CrossBertConfig(object):
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `CrossBertConfig` from a Python dictionary of parameters."""
config = CrossBertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `CrossBertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as CrossBertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class CrossBertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(CrossBertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class CrossBertEmbeddings(nn.Module):
class CrossEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(CrossBertEmbeddings, self).__init__()
super(CrossEmbeddings, self).__init__()
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, concat_embeddings, concat_type=None):
@ -198,9 +137,9 @@ class CrossBertEmbeddings(nn.Module):
embeddings = self.dropout(embeddings)
return embeddings
class CrossBertSelfAttention(nn.Module):
class CrossSelfAttention(nn.Module):
def __init__(self, config):
super(CrossBertSelfAttention, self).__init__()
super(CrossSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
@ -232,7 +171,7 @@ class CrossBertSelfAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in CrossBertModel forward() function)
# Apply the attention mask (precomputed for all layers in CrossModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
@ -249,11 +188,11 @@ class CrossBertSelfAttention(nn.Module):
return context_layer
class CrossBertSelfOutput(nn.Module):
class CrossSelfOutput(nn.Module):
def __init__(self, config):
super(CrossBertSelfOutput, self).__init__()
super(CrossSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@ -263,11 +202,11 @@ class CrossBertSelfOutput(nn.Module):
return hidden_states
class CrossBertAttention(nn.Module):
class CrossAttention(nn.Module):
def __init__(self, config):
super(CrossBertAttention, self).__init__()
self.self = CrossBertSelfAttention(config)
self.output = CrossBertSelfOutput(config)
super(CrossAttention, self).__init__()
self.self = CrossSelfAttention(config)
self.output = CrossSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
@ -275,9 +214,9 @@ class CrossBertAttention(nn.Module):
return attention_output
class CrossBertIntermediate(nn.Module):
class CrossIntermediate(nn.Module):
def __init__(self, config):
super(CrossBertIntermediate, self).__init__()
super(CrossIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
@ -288,11 +227,11 @@ class CrossBertIntermediate(nn.Module):
return hidden_states
class CrossBertOutput(nn.Module):
class CrossOutput(nn.Module):
def __init__(self, config):
super(CrossBertOutput, self).__init__()
super(CrossOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@ -302,12 +241,12 @@ class CrossBertOutput(nn.Module):
return hidden_states
class CrossBertLayer(nn.Module):
class CrossLayer(nn.Module):
def __init__(self, config):
super(CrossBertLayer, self).__init__()
self.attention = CrossBertAttention(config)
self.intermediate = CrossBertIntermediate(config)
self.output = CrossBertOutput(config)
super(CrossLayer, self).__init__()
self.attention = CrossAttention(config)
self.intermediate = CrossIntermediate(config)
self.output = CrossOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
@ -316,10 +255,10 @@ class CrossBertLayer(nn.Module):
return layer_output
class CrossBertEncoder(nn.Module):
class CrossEncoder(nn.Module):
def __init__(self, config):
super(CrossBertEncoder, self).__init__()
layer = CrossBertLayer(config)
super(CrossEncoder, self).__init__()
layer = CrossLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
@ -333,9 +272,9 @@ class CrossBertEncoder(nn.Module):
return all_encoder_layers
class CrossBertPooler(nn.Module):
class CrossPooler(nn.Module):
def __init__(self, config):
super(CrossBertPooler, self).__init__()
super(CrossPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
@ -348,13 +287,13 @@ class CrossBertPooler(nn.Module):
return pooled_output
class CrossBertPredictionHeadTransform(nn.Module):
class CrossPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(CrossBertPredictionHeadTransform, self).__init__()
super(CrossPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@ -363,18 +302,18 @@ class CrossBertPredictionHeadTransform(nn.Module):
return hidden_states
class CrossBertLMPredictionHead(nn.Module):
def __init__(self, config, cross_bert_model_embedding_weights):
super(CrossBertLMPredictionHead, self).__init__()
self.transform = CrossBertPredictionHeadTransform(config)
class CrossLMPredictionHead(nn.Module):
def __init__(self, config, cross_model_embedding_weights):
super(CrossLMPredictionHead, self).__init__()
self.transform = CrossPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(cross_bert_model_embedding_weights.size(1),
cross_bert_model_embedding_weights.size(0),
self.decoder = nn.Linear(cross_model_embedding_weights.size(1),
cross_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = cross_bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(cross_bert_model_embedding_weights.size(0)))
self.decoder.weight = cross_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(cross_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
@ -382,19 +321,19 @@ class CrossBertLMPredictionHead(nn.Module):
return hidden_states
class CrossBertOnlyMLMHead(nn.Module):
def __init__(self, config, cross_bert_model_embedding_weights):
super(CrossBertOnlyMLMHead, self).__init__()
self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
class CrossOnlyMLMHead(nn.Module):
def __init__(self, config, cross_model_embedding_weights):
super(CrossOnlyMLMHead, self).__init__()
self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class CrossBertOnlyNSPHead(nn.Module):
class CrossOnlyNSPHead(nn.Module):
def __init__(self, config):
super(CrossBertOnlyNSPHead, self).__init__()
super(CrossOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
@ -402,10 +341,10 @@ class CrossBertOnlyNSPHead(nn.Module):
return seq_relationship_score
class CrossBertPreTrainingHeads(nn.Module):
def __init__(self, config, cross_bert_model_embedding_weights):
super(CrossBertPreTrainingHeads, self).__init__()
self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
class CrossPreTrainingHeads(nn.Module):
def __init__(self, config, cross_model_embedding_weights):
super(CrossPreTrainingHeads, self).__init__()
self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
@ -413,185 +352,16 @@ class CrossBertPreTrainingHeads(nn.Module):
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedCrossBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedCrossBertModel, self).__init__()
if not isinstance(config, CrossBertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `CrossBertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_cross_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, CrossBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
'''
Factored out of `from_pretrained`.
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = CrossBertConfig.from_json_file(config_file)
config.type_vocab_size = type_vocab_size
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'cross_bert') else 'cross_bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
"""
Instantiate a PreTrainedCrossBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `cross_bert-base`
. `cross_bert-large`
- a path or url to a pretrained model archive containing:
. `cross-bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a CrossBertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific CrossBert class
(ex: num_classes for CrossBertForSequenceClassification)
"""
config, state_dict = PreTrainedCrossBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedCrossBertModel.init_preweight(model, state_dict)
return model
class CrossBertModel(PreTrainedCrossBertModel):
class CrossModel(PreTrainedModel):
def __init__(self, config):
super(CrossBertModel, self).__init__(config)
self.embeddings = CrossBertEmbeddings(config)
self.encoder = CrossBertEncoder(config)
self.pooler = CrossBertPooler(config)
self.apply(self.init_cross_bert_weights)
super(CrossModel, self).__init__(config)
self.embeddings = CrossEmbeddings(config)
self.encoder = CrossEncoder(config)
self.pooler = CrossPooler(config)
self.apply(self.init_weights)
def forward(self, type, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True,
attention_sentenceA_mask=None,
attention_sentenceB_mask=None):
def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))
@@ -610,7 +380,7 @@ class CrossBertModel(PreTrainedCrossBertModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(concat_input, concat_type)

406
modules/module_decoder.py Normal file

@@ -0,0 +1,406 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {}
CONFIG_NAME = 'decoder_config.json'
WEIGHTS_NAME = 'decoder_pytorch_model.bin'
class DecoderConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `DecoderModel`.
"""
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
config_name = CONFIG_NAME
weights_name = WEIGHTS_NAME
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_vocab_size=2,
initializer_range=0.02,
max_target_embeddings=128,
num_decoder_layers=1):
"""Constructs DecoderConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `DecoderModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`DecoderModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
max_target_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
num_decoder_layers: Number of stacked layers in the Transformer decoder.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.max_target_embeddings = max_target_embeddings
self.num_decoder_layers = num_decoder_layers
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, decoder_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(decoder_model_embedding_weights.size(1),
decoder_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = decoder_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(decoder_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
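# Small illustration, not from the original file, of the weight tying used by
# BertLMPredictionHead above: the output projection reuses the decoder word
# embedding matrix and only adds a per-token bias.
from torch import nn
tied_embeddings = nn.Embedding(100, 16)        # stand-in word embeddings
tied_decoder = nn.Linear(16, 100, bias=False)
tied_decoder.weight = tied_embeddings.weight   # same Parameter object, no copy
assert tied_decoder.weight.data_ptr() == tied_embeddings.weight.data_ptr()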
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, decoder_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, decoder_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, config):
super(MultiHeadAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, q, k, v, attention_mask):
mixed_query_layer = self.query(q)
mixed_key_layer = self.key(k)
mixed_value_layer = self.value(v)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer, attention_scores
class PositionwiseFeedForward(nn.Module):
''' A two-feed-forward-layer module '''
def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
self.layer_norm = nn.LayerNorm(d_in)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
output = x.transpose(1, 2)
output = self.w_2(ACT2FN["gelu"](self.w_1(output)))
output = output.transpose(1, 2)
output = self.dropout(output)
output = self.layer_norm(output + residual)
return output
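# Shape sketch, not from the original file (import assumes the repository root
# is on PYTHONPATH): the position-wise feed-forward block preserves the
# (batch, seq_len, d_in) layout of its input.
import torch
from modules.module_decoder import PositionwiseFeedForward
ffn = PositionwiseFeedForward(d_in=768, d_hid=3072)
print(ffn(torch.randn(2, 16, 768)).shape)      # torch.Size([2, 16, 768])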
class DecoderAttention(nn.Module):
def __init__(self, config):
super(DecoderAttention, self).__init__()
self.att = MultiHeadAttention(config)
self.output = BertSelfOutput(config)
def forward(self, q, k, v, attention_mask):
att_output, attention_probs = self.att(q, k, v, attention_mask)
attention_output = self.output(att_output, q)
return attention_output, attention_probs
class DecoderLayer(nn.Module):
def __init__(self, config):
super(DecoderLayer, self).__init__()
self.slf_attn = DecoderAttention(config)
self.enc_attn = DecoderAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
intermediate_output = self.intermediate(dec_output)
dec_output = self.output(intermediate_output, dec_output)
return dec_output, dec_att_scores
class DecoderEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
super(DecoderEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
self.word_embeddings.weight = decoder_word_embeddings_weight
self.position_embeddings.weight = decoder_position_embeddings_weight
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class Decoder(nn.Module):
def __init__(self, config):
super(Decoder, self).__init__()
layer = DecoderLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])
def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
dec_att_scores = None
all_encoder_layers = []
all_dec_att_probs = []
for layer_module in self.layer:
hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
all_dec_att_probs.append(dec_att_scores)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
all_dec_att_probs.append(dec_att_scores)
return all_encoder_layers, all_dec_att_probs
class DecoderClassifier(nn.Module):
def __init__(self, config, embedding_weights):
super(DecoderClassifier, self).__init__()
self.cls = BertOnlyMLMHead(config, embedding_weights)
def forward(self, hidden_states):
cls_scores = self.cls(hidden_states)
return cls_scores
class DecoderModel(PreTrainedModel):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
final_norm (bool, optional): apply layer norm to the output of the
final decoder layer (default: True).
"""
def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
super(DecoderModel, self).__init__(config)
self.config = config
self.max_target_length = config.max_target_embeddings
self.embeddings = DecoderEmbeddings(config, decoder_word_embeddings_weight, decoder_position_embeddings_weight)
self.decoder = Decoder(config)
self.classifier = DecoderClassifier(config, decoder_word_embeddings_weight)
self.apply(self.init_weights)
def forward(self, input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
"""
Args:
input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len, vocab)`
- the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)`
"""
embedding_output = self.embeddings(input_ids)
extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2) # b x 1 x 1 x ls
extended_encoder_mask = extended_encoder_mask.to(dtype=self.dtype) # fp16 compatibility
extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0
extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
extended_answer_mask = extended_answer_mask.to(dtype=self.dtype) # fp16 compatibility
sz_b, len_s, _ = embedding_output.size()
subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1) # b x 1 x ls x ls
slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=self.dtype)
self_attn_mask = slf_attn_mask * -10000.0
decoded_layers, dec_att_scores = self.decoder(embedding_output,
encoder_outs,
self_attn_mask,
extended_encoder_mask,
)
sequence_output = decoded_layers[-1]
cls_scores = self.classifier(sequence_output)
return cls_scores
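# Illustrative sketch, not from the original file, of the self-attention mask
# construction in DecoderModel.forward above: an upper-triangular "subsequent"
# mask is merged with the target padding mask and turned into an additive
# -10000.0 mask.
import torch
answer_mask = torch.tensor([[1., 1., 1., 0.]])                 # b x ls, last position padded
len_s = answer_mask.size(1)
subsequent = torch.triu(torch.ones(len_s, len_s), diagonal=1)  # 1 above the diagonal
self_attn = subsequent.unsqueeze(0).unsqueeze(1)               # b x 1 x ls x ls (via broadcast)
extended_answer = answer_mask.unsqueeze(1).unsqueeze(2)        # b x 1 x 1 x ls
additive_mask = ((1.0 - extended_answer) + self_attn).gt(0).float() * -10000.0
print(additive_mask[0, 0])                                     # row i: what step i may attend to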


@@ -27,43 +27,27 @@ import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'visual-bert-base': "[TODO]",
'visual-bert-large': "[TODO]",
}
CONFIG_NAME = 'visual_bert_config.json'
PRETRAINED_MODEL_ARCHIVE_MAP = {}
CONFIG_NAME = 'visual_config.json'
WEIGHTS_NAME = 'visual_pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class VisualBertConfig(object):
"""Configuration class to store the configuration of a `VisualBertModel`.
class VisualConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `VisualModel`.
"""
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
config_name = CONFIG_NAME
weights_name = WEIGHTS_NAME
def __init__(self,
vocab_size_or_config_json_file=4096,
hidden_size=768,
@@ -75,7 +59,7 @@ class VisualBertConfig(object):
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
initializer_range=0.02):
"""Constructs VisualBertConfig.
"""Constructs VisualConfig.
Args:
vocab_size_or_config_json_file: Size of the encoder layers and the pooler layer.
@@ -117,67 +101,18 @@ class VisualBertConfig(object):
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `VisualBertConfig` from a Python dictionary of parameters."""
config = VisualBertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `VisualBertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as VisualBertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class VisualBertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(VisualBertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class VisualBertEmbeddings(nn.Module):
class VisualEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(VisualBertEmbeddings, self).__init__()
super(VisualEmbeddings, self).__init__()
self.word_embeddings = nn.Linear(config.vocab_size, config.hidden_size)
# self.transform_act_fn = ACT2FN[config.hidden_act] \
# if isinstance(config.hidden_act, str) else config.hidden_act
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_embeddings):
@@ -195,10 +130,9 @@ class VisualBertEmbeddings(nn.Module):
embeddings = self.dropout(embeddings)
return embeddings
class VisualBertSelfAttention(nn.Module):
class VisualSelfAttention(nn.Module):
def __init__(self, config):
super(VisualBertSelfAttention, self).__init__()
super(VisualSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
@@ -230,7 +164,7 @@ class VisualBertSelfAttention(nn.Module):
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in the VisualBertModel forward() function)
# Apply the attention mask (precomputed for all layers in the VisualModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
@@ -247,11 +181,11 @@ class VisualBertSelfAttention(nn.Module):
return context_layer
class VisualBertSelfOutput(nn.Module):
class VisualSelfOutput(nn.Module):
def __init__(self, config):
super(VisualBertSelfOutput, self).__init__()
super(VisualSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -261,11 +195,11 @@ class VisualBertSelfOutput(nn.Module):
return hidden_states
class VisualBertAttention(nn.Module):
class VisualAttention(nn.Module):
def __init__(self, config):
super(VisualBertAttention, self).__init__()
self.self = VisualBertSelfAttention(config)
self.output = VisualBertSelfOutput(config)
super(VisualAttention, self).__init__()
self.self = VisualSelfAttention(config)
self.output = VisualSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
@@ -273,9 +207,9 @@ class VisualBertAttention(nn.Module):
return attention_output
class VisualBertIntermediate(nn.Module):
class VisualIntermediate(nn.Module):
def __init__(self, config):
super(VisualBertIntermediate, self).__init__()
super(VisualIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
@@ -286,11 +220,11 @@ class VisualBertIntermediate(nn.Module):
return hidden_states
class VisualBertOutput(nn.Module):
class VisualOutput(nn.Module):
def __init__(self, config):
super(VisualBertOutput, self).__init__()
super(VisualOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
@@ -300,12 +234,12 @@ class VisualBertOutput(nn.Module):
return hidden_states
class VisualBertLayer(nn.Module):
class VisualLayer(nn.Module):
def __init__(self, config):
super(VisualBertLayer, self).__init__()
self.attention = VisualBertAttention(config)
self.intermediate = VisualBertIntermediate(config)
self.output = VisualBertOutput(config)
super(VisualLayer, self).__init__()
self.attention = VisualAttention(config)
self.intermediate = VisualIntermediate(config)
self.output = VisualOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
@@ -314,10 +248,10 @@ class VisualBertLayer(nn.Module):
return layer_output
class VisualBertEncoder(nn.Module):
class VisualEncoder(nn.Module):
def __init__(self, config):
super(VisualBertEncoder, self).__init__()
layer = VisualBertLayer(config)
super(VisualEncoder, self).__init__()
layer = VisualLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
@@ -331,9 +265,9 @@ class VisualBertEncoder(nn.Module):
return all_encoder_layers
class VisualBertPooler(nn.Module):
class VisualPooler(nn.Module):
def __init__(self, config):
super(VisualBertPooler, self).__init__()
super(VisualPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
@@ -346,13 +280,13 @@ class VisualBertPooler(nn.Module):
return pooled_output
class VisualBertPredictionHeadTransform(nn.Module):
class VisualPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(VisualBertPredictionHeadTransform, self).__init__()
super(VisualPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -361,18 +295,15 @@ class VisualBertPredictionHeadTransform(nn.Module):
return hidden_states
class VisualBertLMPredictionHead(nn.Module):
"""
Note: It is a trap for `visual_bert_model_embedding_weights`, which is different from Embedding of BERT
"""
def __init__(self, config, visual_bert_model_embedding_weights):
super(VisualBertLMPredictionHead, self).__init__()
self.transform = VisualBertPredictionHeadTransform(config)
class VisualLMPredictionHead(nn.Module):
def __init__(self, config, visual_model_embedding_weights):
super(VisualLMPredictionHead, self).__init__()
self.transform = VisualPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.weight = visual_bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(visual_bert_model_embedding_weights.size(1)))
self.weight = visual_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(visual_model_embedding_weights.size(1)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
@@ -380,19 +311,19 @@ class VisualBertLMPredictionHead(nn.Module):
return hidden_states
class VisualBertOnlyMLMHead(nn.Module):
def __init__(self, config, visual_bert_model_embedding_weights):
super(VisualBertOnlyMLMHead, self).__init__()
self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)
class VisualOnlyMLMHead(nn.Module):
def __init__(self, config, visual_model_embedding_weights):
super(VisualOnlyMLMHead, self).__init__()
self.predictions = VisualLMPredictionHead(config, visual_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class VisualBertOnlyNSPHead(nn.Module):
class VisualOnlyNSPHead(nn.Module):
def __init__(self, config):
super(VisualBertOnlyNSPHead, self).__init__()
super(VisualOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
@@ -400,10 +331,10 @@ class VisualBertOnlyNSPHead(nn.Module):
return seq_relationship_score
class VisualBertPreTrainingHeads(nn.Module):
def __init__(self, config, visual_bert_model_embedding_weights):
super(VisualBertPreTrainingHeads, self).__init__()
self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)
class VisualPreTrainingHeads(nn.Module):
def __init__(self, config, visual_model_embedding_weights):
super(VisualPreTrainingHeads, self).__init__()
self.predictions = VisualLMPredictionHead(config, visual_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
@@ -412,178 +343,11 @@ class VisualBertPreTrainingHeads(nn.Module):
return prediction_scores, seq_relationship_score
class PreTrainedVisualBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedVisualBertModel, self).__init__()
if not isinstance(config, VisualBertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `VisualBertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_visual_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, VisualBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, state_dict, task_config=None):
'''
abstract from `from_pretrained`
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = VisualBertConfig.from_json_file(config_file)
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'visual_bert') else 'visual_bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedVisualBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `visual-bert-base`
. `visual-bert-large`
- a path or url to a pretrained model archive containing:
. `visual_bert_config.json` a configuration file for the model
. `visual_pytorch_model.bin` a PyTorch dump of a VisualBertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific VisualBert class
(ex: num_classes for VisualBertForSequenceClassification)
"""
config, state_dict = PreTrainedVisualBertModel.get_config(pretrained_model_name, cache_dir, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedVisualBertModel.init_preweight(model, state_dict)
return model
class VisualBertModel(PreTrainedVisualBertModel):
"""VisualBert model ("Bidirectional Embedding Representations from a Transformer").
class VisualModel(PreTrainedModel):
"""Visual model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a VisualBertConfig class instance with the configuration to build a new model
config: a VisualConfig class instance with the configuration to build a new model
Inputs:
`type`: a str, indicates which masking will be used in the attention, choice from [`bi`, `seq`, `gen`]
@@ -592,7 +356,7 @@ class VisualBertModel(PreTrainedVisualBertModel):
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see Bert paper for more details).
a `sentence B` token (see the paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
@@ -602,13 +366,13 @@ class VisualBertModel(PreTrainedVisualBertModel):
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for VisualBert-base, 24 for VisualBert-large), each
of each attention block (i.e. 12 full sequences for Visual-base, 24 for Visual-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the
input (`CLF`) to train on the Next-Sentence task (see Bert's paper).
input (`CLS`) to train on the Next-Sentence task (see the paper).
Example usage:
```python
@@ -616,22 +380,21 @@ class VisualBertModel(PreTrainedVisualBertModel):
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
config = modeling.VisualBertConfig(vocab_size_or_config_json_file=4096, hidden_size=768,
config = modeling.VisualConfig(vocab_size_or_config_json_file=4096, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.VisualBertModel(config=config)
model = modeling.VisualModel(config=config)
all_encoder_layers, pooled_output = model(video, video_mask)
```
"""
def __init__(self, config):
super(VisualBertModel, self).__init__(config)
self.embeddings = VisualBertEmbeddings(config)
self.encoder = VisualBertEncoder(config)
self.pooler = VisualBertPooler(config)
self.apply(self.init_visual_bert_weights)
super(VisualModel, self).__init__(config)
self.embeddings = VisualEmbeddings(config)
self.encoder = VisualEncoder(config)
self.pooler = VisualPooler(config)
self.apply(self.init_weights)
def forward(self, type, video, attention_mask=None, output_all_encoded_layers=True,
attention_sentenceA_mask=None, attention_sentenceB_mask=None):
def forward(self, video, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones(video.size(0), video.size(1))
@@ -648,7 +411,7 @@ class VisualBertModel(PreTrainedVisualBertModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(video)


@@ -110,8 +110,6 @@ class BertAdam(Optimizer):
if closure is not None:
loss = closure()
warned_for_t_total = False
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
@@ -139,8 +137,10 @@ class BertAdam(Optimizer):
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
next_m.mul_(beta1).add_(1 - beta1, grad)
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
# next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7
next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
# next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7
next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
update = next_m / (next_v.sqrt() + group['e'])
# Just adding the square of the weights to the loss function is *not*
@@ -157,13 +157,6 @@ class BertAdam(Optimizer):
schedule_fct = SCHEDULES[group['schedule']]
progress = state['step']/group['t_total']
lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
# warning for exceeding t_total (only active with warmup_linear)
if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
logger.warning(
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
"Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
warned_for_t_total = True
# end warning
else:
lr_scheduled = group['lr']
@@ -172,9 +165,4 @@ class BertAdam(Optimizer):
state['step'] += 1
# step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
# No bias correction
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
return loss


126
modules/until_config.py Normal file

@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import logging
import tarfile
import tempfile
import shutil
import torch
from .file_utils import cached_path
logger = logging.getLogger(__name__)
class PretrainedConfig(object):
pretrained_model_archive_map = {}
config_name = ""
weights_name = ""
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, cls.config_name)
config = cls.from_json_file(config_file)
config.type_vocab_size = type_vocab_size
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, cls.weights_name)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path, map_location='cpu')
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

251
modules/until_module.py Normal file

@@ -0,0 +1,251 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import math
from modules.until_config import PretrainedConfig
logger = logging.getLogger(__name__)
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(LayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
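# Quick sanity sketch, not from the original file: with default affine
# parameters this LayerNorm should agree numerically with torch.nn.LayerNorm
# (import assumes the repository root is on PYTHONPATH).
import torch
from modules.until_module import LayerNorm
x = torch.randn(2, 4, 8)
print(torch.allclose(LayerNorm(8)(x), torch.nn.LayerNorm(8, eps=1e-12)(x), atol=1e-5))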
class PreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, LayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
raise NotImplementedError
@classmethod
def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
if prefix is not None:
old_keys = []
new_keys = []
for key in state_dict.keys():
old_keys.append(key)
new_keys.append(prefix + key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='')
if prefix is None and (task_config is None or task_config.local_rank == 0):
logger.info("-" * 20)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}"
.format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}"
.format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}"
.format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@property
def dtype(self):
"""
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
try:
return next(self.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module):
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
@classmethod
def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
"""
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = cls.init_preweight(model, state_dict)
return model
##################################
###### LOSS FUNCTION #############
##################################
class CrossEn(nn.Module):
def __init__(self,):
super(CrossEn, self).__init__()
def forward(self, sim_matrix):
logpt = F.log_softmax(sim_matrix, dim=-1)
logpt = torch.diag(logpt)
nce_loss = -logpt
sim_loss = nce_loss.mean()
return sim_loss
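# Usage sketch, not from the original file: CrossEn treats the diagonal of a
# text-video similarity matrix as the positive pairs, so a strongly diagonal
# matrix yields a loss near zero (import assumes the repository root is on
# PYTHONPATH).
import torch
from modules.until_module import CrossEn
sim_matrix = torch.eye(4) * 10.0        # 4 matched pairs on the diagonal
print(CrossEn()(sim_matrix).item())     # roughly 1e-4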
class MILNCELoss(nn.Module):
def __init__(self, batch_size=1, n_pair=1,):
super(MILNCELoss, self).__init__()
self.batch_size = batch_size
self.n_pair = n_pair
torch_v = float(".".join(torch.__version__.split(".")[:2]))
self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8
def forward(self, sim_matrix):
mm_mask = np.eye(self.batch_size)
mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)
from_text_matrix = sim_matrix + mm_mask * -1e12
from_video_matrix = sim_matrix.transpose(1, 0)
new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
logpt = F.log_softmax(new_sim_matrix, dim=-1)
mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
new_logpt = -torch.logsumexp(masked_logpt, dim=-1)
logpt_choice = torch.zeros_like(new_logpt)
mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
logpt_choice[mark_ind] = 1
sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
return sim_loss
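# Usage sketch, not from the original file: with n_pair=1 MILNCELoss reduces to
# a symmetric NCE objective over the batch, so a diagonal-dominant similarity
# matrix again gives a small loss (import assumes the repository root is on
# PYTHONPATH).
import torch
from modules.until_module import MILNCELoss
sim_matrix = torch.tensor([[5.0, 0.0],
                           [0.0, 5.0]])
print(MILNCELoss(batch_size=2, n_pair=1)(sim_matrix).item())   # roughly 0.01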
class MaxMarginRankingLoss(nn.Module):
def __init__(self,
margin=1.0,
negative_weighting=False,
batch_size=1,
n_pair=1,
hard_negative_rate=0.5,
):
super(MaxMarginRankingLoss, self).__init__()
self.margin = margin
self.n_pair = n_pair
self.batch_size = batch_size
easy_negative_rate = 1 - hard_negative_rate
self.easy_negative_rate = easy_negative_rate
self.negative_weighting = negative_weighting
if n_pair > 1 and batch_size > 1:
alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
self.mm_mask = mm_mask.float()
def forward(self, x):
d = torch.diag(x)
max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
F.relu(self.margin + x - d.view(1, -1))
if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
max_margin = max_margin * self.mm_mask.to(max_margin.device)
return max_margin.mean()
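# Usage sketch, not from the original file: with the default settings this is a
# plain bidirectional max-margin ranking loss over a similarity matrix whose
# diagonal holds the matched pairs (import assumes the repository root is on
# PYTHONPATH).
import torch
from modules.until_module import MaxMarginRankingLoss
sim_matrix = torch.tensor([[0.9, 0.1],
                           [0.2, 0.8]])
print(MaxMarginRankingLoss(margin=0.2)(sim_matrix).item())     # 0.2 for this toy matrix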


@@ -1,713 +0,0 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'decoder-bert-base': "[TODO]",
'decoder-bert-large': "[TODO]",
}
CONFIG_NAME = 'decoder_bert_config.json'
WEIGHTS_NAME = 'decoder_pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_vocab_size=2,
initializer_range=0.02,
max_target_embeddings=128,
num_decoder_layers=1):
"""Constructs BertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
max_target_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
num_decoder_layers: Number of layers in the Transformer decoder stack.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.max_target_embeddings = max_target_embeddings
self.num_decoder_layers = num_decoder_layers
else:
raise ValueError("First argument must be either a vocabulary size (int) "
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
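# --- Illustrative usage sketch (not part of the original file) -------------
# A config can be built from an int vocabulary size (keyword arguments
# override the defaults) or read from JSON, and it round-trips through
# to_dict()/from_dict().
def _example_decoder_config():
    config = BertConfig(vocab_size_or_config_json_file=30522, num_decoder_layers=3)
    clone = BertConfig.from_dict(config.to_dict())
    assert clone.num_decoder_layers == config.num_decoder_layers
    return config.to_json_string()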
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as DecoderBertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class DecoderBertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(DecoderBertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class PreTrainedDecoderBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedDecoderBertModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, DecoderBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def resize_token_embeddings(self, new_num_tokens=None):
model = self.module if hasattr(self, 'module') else self
base_model = getattr(model, "bert", model) # get the base model if needed
old_embeddings = base_model.embeddings.word_embeddings
# get_resized_embeddings <-----
if new_num_tokens is None:
return old_embeddings
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
if old_num_tokens == new_num_tokens:
return old_embeddings
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
device = old_embeddings.weight.device
new_embeddings.to(device)
# initialize all new embeddings (in particular added tokens)
self.init_bert_weights(new_embeddings)
# Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
# ----->
# Update base model and current model config
base_model.embeddings.word_embeddings = new_embeddings
base_model.config.vocab_size = new_num_tokens
model.config.vocab_size = new_num_tokens
model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)
@classmethod
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
'''
Helper abstracted from `from_pretrained`: resolves the model config and state dict.
'''
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
if os.path.exists(archive_file) is False:
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
if task_config is None or task_config.local_rank == 0:
logger.error(
"Model name '{}' was not found in model name list. "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
archive_file))
return None
if resolved_archive_file == archive_file:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {}".format(archive_file))
else:
if task_config is None or task_config.local_rank == 0:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
if task_config is None or task_config.local_rank == 0:
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
config.type_vocab_size = type_vocab_size
if task_config is None or task_config.local_rank == 0:
logger.info("Model config {}".format(config))
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
if os.path.exists(weights_path):
state_dict = torch.load(weights_path)
else:
if task_config is None or task_config.local_rank == 0:
logger.info("Weight doesn't exsits. {}".format(weights_path))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return config, state_dict
@classmethod
def init_preweight(cls, model, state_dict):
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
if len(error_msgs) > 0:
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
return model
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None,
cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
"""
Instantiate a PreTrainedDecoderBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load, selected from:
    . `decoder-bert-base`
    . `decoder-bert-large`
- a path or url to a pretrained model archive containing:
    . `decoder_bert_config.json` a configuration file for the model
    . `decoder_pytorch_model.bin` a PyTorch dump of a pretrained decoder instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of the pre-trained weights
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_classes for BertForSequenceClassification)
"""
config, state_dict = PreTrainedDecoderBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
return model
model = PreTrainedDecoderBertModel.init_preweight(model, state_dict)
return model
# import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropout(attn)
output = torch.bmm(attn, v)
return output, attn
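# --- Illustrative usage sketch (not part of the original file) -------------
# q, k, v are (batch, seq_len, d_k); temperature is typically sqrt(d_k).
def _example_scaled_dot_product_attention():
    q = k = v = torch.randn(2, 10, 64)
    attention = ScaledDotProductAttention(temperature=64 ** 0.5)
    output, attn = attention(q, k, v)                     # (2, 10, 64), (2, 10, 10)
    return output.shape, attn.shape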
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, config):
super(MultiHeadAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, q, k, v, attention_mask):
mixed_query_layer = self.query(q)
mixed_key_layer = self.key(k)
mixed_value_layer = self.value(v)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer, attention_scores
class PositionwiseFeedForward(nn.Module):
''' A two-feed-forward-layer module '''
def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
self.layer_norm = nn.LayerNorm(d_in)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
output = x.transpose(1, 2)
output = self.w_2(gelu(self.w_1(output)))
output = output.transpose(1, 2)
output = self.dropout(output)
output = self.layer_norm(output + residual)
return output
class DecoderAttention(nn.Module):
def __init__(self, config):
super(DecoderAttention, self).__init__()
self.att = MultiHeadAttention(config)
self.output = BertSelfOutput(config)
def forward(self, q, k, v, attention_mask):
att_output, attention_probs = self.att(q, k, v, attention_mask)
attention_output = self.output(att_output, q)
return attention_output, attention_probs
class DecoderLayer(nn.Module):
def __init__(self, config):
super(DecoderLayer, self).__init__()
self.slf_attn = DecoderAttention(config)
self.enc_attn = DecoderAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
intermediate_output = self.intermediate(dec_output)
dec_output = self.output(intermediate_output, dec_output)
return dec_output, dec_att_scores
class BertDecoderEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
super(BertDecoderEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
self.word_embeddings.weight = bert_word_embeddings_weight
self.position_embeddings.weight = bert_position_embeddings_weight
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertDecoder(nn.Module):
def __init__(self, config):
super(BertDecoder, self).__init__()
layer = DecoderLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])
def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
dec_att_scores = None
all_encoder_layers = []
all_dec_att_probs = []
for layer_module in self.layer:
hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
all_dec_att_probs.append(dec_att_scores)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
all_dec_att_probs.append(dec_att_scores)
return all_encoder_layers, all_dec_att_probs
class BertDecoderClassifier(nn.Module):
def __init__(self, config, embedding_weights):
super(BertDecoderClassifier, self).__init__()
self.cls = BertOnlyMLMHead(config, embedding_weights)
def forward(self, hidden_states):
cls_scores = self.cls(hidden_states)
return cls_scores
class DecoderBertModel(nn.Module):
def init_decoder_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, DecoderBertLayerNorm):
if 'beta' in dir(module) and 'gamma' in dir(module):
module.beta.data.zero_()
module.gamma.data.fill_(1.0)
else:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
"""
Transformer decoder consisting of *config.num_decoder_layers* layers. Each layer
is a :class:`DecoderLayer`.
Args:
    config (BertConfig): model configuration
    bert_word_embeddings_weight: word embedding weights shared with the encoder
    bert_position_embeddings_weight: position embedding weights shared with the encoder
"""
def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
super(DecoderBertModel, self).__init__()
self.config = config
self.max_target_length = config.max_target_embeddings
self.embeddings = BertDecoderEmbeddings(config, bert_word_embeddings_weight, bert_position_embeddings_weight)
self.decoder = BertDecoder(config)
self.classifier = BertDecoderClassifier(config, bert_word_embeddings_weight)
self.apply(self.init_decoder_bert_weights)
def forward(self, input_ids, encoder_input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
"""
Args:
input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len, vocab)`
- the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)`
"""
embedding_output = self.embeddings(input_ids)
extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2) # b x 1 x 1 x ls
extended_encoder_mask = extended_encoder_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0
extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
extended_answer_mask = extended_answer_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
sz_b, len_s, _ = embedding_output.size()
subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1) # b x 1 x ls x ls
slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=next(self.parameters()).dtype)
self_attn_mask = slf_attn_mask * -10000.0
decoded_layers, dec_att_scores = self.decoder(embedding_output,
encoder_outs,
self_attn_mask,
extended_encoder_mask,
)
sequence_output = decoded_layers[-1]
cls_scores = self.classifier(sequence_output)
return cls_scores
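# --- Illustrative sketch (not part of the original file) -------------------
# How forward() builds the decoder's self-attention mask: the target padding
# mask is combined with an upper-triangular "subsequent" mask so that position
# i can only attend to positions <= i; blocked entries are pushed to a large
# negative value before the softmax inside MultiHeadAttention.
def _example_causal_mask(answer_mask):
    # answer_mask: (batch, tgt_len), 1 for real tokens and 0 for padding
    sz_b, len_s = answer_mask.size()
    subsequent = torch.triu(torch.ones(len_s, len_s), diagonal=1)           # (ls, ls)
    subsequent = subsequent.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1)  # (b, 1, ls, ls)
    padding = (1.0 - answer_mask.float()).unsqueeze(1).unsqueeze(2)         # (b, 1, 1, ls)
    return (padding + subsequent).gt(0).float() * -10000.0                  # (b, 1, ls, ls)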

5
requirements.txt Normal file
View file

@ -0,0 +1,5 @@
torch==1.7.0
tqdm
boto3
requests
pandas

16
util.py
View file

@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import threading
from torch._utils import ExceptionWrapper
import logging
def get_a_var(obj):
if isinstance(obj, torch.Tensor):
@ -56,4 +57,17 @@ def parallel_apply(fct, model, inputs, device_ids):
if isinstance(output, ExceptionWrapper):
output.reraise()
outputs.append(output)
return outputs
return outputs
def get_logger(filename=None):
logger = logging.getLogger('logger')
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
if filename is not None:
handler = logging.FileHandler(filename)
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
logging.getLogger().addHandler(handler)
return logger
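# --- Illustrative usage sketch (not part of the original file) -------------
# get_logger logs to the console via basicConfig and, when a filename is
# given, additionally to that file.
def _example_logging(log_file="train.log"):
    logger = get_logger(log_file)
    logger.info("example message")
    return logger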