update
This commit is contained in:
Parent: ba6b9ea0f1
Commit: 4308eaff6c

@@ -1 +1,2 @@
.idea
READMEINS.md

README.md
@@ -1,9 +1,10 @@
The implementation of paper [**UniVL: A Unified Video and Language Pre-Training Model for Multimodal Understanding and Generation**](https://arxiv.org/abs/2002.06353).

# Preliminary
Execute the scripts below in the main folder first.
```
mkdir modules/bert-base-uncased
cd modules/bert-base-uncased/
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt
mv bert-base-uncased-vocab.txt vocab.txt
wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz
@@ -12,81 +13,172 @@ rm bert-base-uncased.tar.gz
cd ../../
```
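
To sanity-check the download, a minimal sketch (assuming the files now sit under `modules/bert-base-uncased/`, and assuming the repo's `modules.tokenization.BertTokenizer`, imported by the training scripts below, accepts a local directory containing `vocab.txt`):
```
# Quick check that vocab.txt is readable by the repo's tokenizer.
# The local path is an assumption based on the setup script above.
from modules.tokenization import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("modules/bert-base-uncased/", do_lower_case=True)
print(tokenizer.tokenize("slice the onion and fry it"))
```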

# Requirements
- python==3.6.9
- torch==1.7.1+cu92
- tqdm
- boto3
- requests
- pandas
- nlg-eval (install Java 1.8.0 or higher first)
```
conda create -n py_univl python=3.6.9 tqdm boto3 requests pandas
conda activate py_univl
pip install torch==1.7.1+cu92
pip install git+https://github.com/Maluuba/nlg-eval.git@master
```

# Finetune on YoucookII
## Retrieval

1. Run retrieval task on **YoucookII**

```
DATATYPE="youcook"
TRAIN_CSV="data/youcookii/youcookii_train.csv"
VAL_CSV="data/youcookii/youcookii_val.csv"
DATA_PATH="data/youcookii/youcookii_data.pickle"
FEATURES_PATH="data/youcookii/youcookii_videos_features.pickle"
INIT_MODEL="weight/univl.pretrained.bin"
OUTPUT_ROOT="ckpts"

python -m torch.distributed.launch --nproc_per_node=4 \
main_task_retrieval.py \
--do_train --num_thread_reader=16 \
--epochs=5 --batch_size=32 \
--n_display=100 \
--train_csv ${TRAIN_CSV} \
--val_csv ${VAL_CSV} \
--data_path ${DATA_PATH} \
--features_path ${FEATURES_PATH} \
--output_dir ${OUTPUT_ROOT}/ckpt_youcook_retrieval --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 48 --max_frames 48 \
--batch_size_val 200 --visual_num_hidden_layers 6 \
--datatype ${DATATYPE} --init_model ${INIT_MODEL}
```
>The results are close to *R@1: 0.2269 - R@5: 0.5245 - R@10: 0.6586 - Median R: 5.0*

2. Run retrieval task on **MSRVTT**
```
DATATYPE="msrvtt"
TRAIN_CSV="data/msrvtt/MSRVTT_train.9k.csv"
VAL_CSV="data/msrvtt/MSRVTT_JSFUSION_test.csv"
DATA_PATH="data/msrvtt/MSRVTT_data.json"
FEATURES_PATH="data/msrvtt/msrvtt_videos_features.pickle"
INIT_MODEL="weight/univl.pretrained.bin"
OUTPUT_ROOT="ckpts"

python -m torch.distributed.launch --nproc_per_node=4 \
main_task_retrieval.py \
--do_train --num_thread_reader=16 \
--epochs=5 --batch_size=128 \
--n_display=100 \
--train_csv ${TRAIN_CSV} \
--val_csv ${VAL_CSV} \
--data_path ${DATA_PATH} \
--features_path ${FEATURES_PATH} \
--output_dir ${OUTPUT_ROOT}/ckpt_msrvtt_retrieval --bert_model bert-base-uncased \
--do_lower_case --lr 5e-5 --max_words 48 --max_frames 48 \
--batch_size_val 200 --visual_num_hidden_layers 6 \
--datatype ${DATATYPE} --expand_msrvtt_sentences --init_model ${INIT_MODEL}
```
>The results are close to *R@1: 0.2720 - R@5: 0.5570 - R@10: 0.6870 - Median R: 4.0*

## Caption
Run caption task on **YoucookII**

```
TRAIN_CSV="data/youcookii/youcookii_train.csv"
VAL_CSV="data/youcookii/youcookii_val.csv"
DATA_PATH="data/youcookii/youcookii_data.pickle"
FEATURES_PATH="data/youcookii/youcookii_videos_features.pickle"
INIT_MODEL="weight/univl.pretrained.bin"
OUTPUT_ROOT="ckpts"

python -m torch.distributed.launch --nproc_per_node=4 \
main_task_caption.py \
--do_train --num_thread_reader=4 \
--epochs=5 --batch_size=16 \
--n_display=100 \
--train_csv ${TRAIN_CSV} \
--val_csv ${VAL_CSV} \
--data_path ${DATA_PATH} \
--features_path ${FEATURES_PATH} \
--output_dir ${OUTPUT_ROOT}/ckpt_youcook_caption --bert_model bert-base-uncased \
--do_lower_case --lr 3e-5 --max_words 128 --max_frames 96 \
--batch_size_val 64 --visual_num_hidden_layers 6 \
--decoder_num_hidden_layers 3 --stage_two \
--init_model ${INIT_MODEL}
```
>The results are close to
```
BLEU_1: 0.4746, BLEU_2: 0.3355, BLEU_3: 0.2423, BLEU_4: 0.1779
METEOR: 0.2261, ROUGE_L: 0.4697, CIDEr: 1.8631
```

# Pretrain on HowTo100M
## Format of csv
```
video_id,feature_file
Z8xhli297v8,Z8xhli297v8.npy
...
```
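
Each row maps a video id to the `.npy` file holding its extracted features. A rough sketch of how the pretraining loader consumes this format, based on the `np.load` path in the dataloader below (paths are illustrative; arrays of shape `[n_frames, video_dim]` are assumed, matching `--video_dim`):
```
import os
import numpy as np
import pandas as pd

# Illustrative paths; the real ones come from --train_csv and --features_path.
csv = pd.read_csv("data/HowTo100M.csv")
row = csv.iloc[0]
feature_file = os.path.join("data/features", row["feature_file"])
video_features = np.load(feature_file)  # expected shape: [n_frames, video_dim]
print(row["video_id"], video_features.shape)
```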

## Stage I
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
python -m torch.distributed.launch --nproc_per_node=8 \
${MODEL_PATH}/main_pretrain.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=1920 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/L48_V6_D3_Phase1 \
--features_path ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--data_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --gradient_accumulation_steps 16 \
--sampled_use_mil --load_checkpoint
```

## Stage II
```
ROOT_PATH=.
DATA_PATH=${ROOT_PATH}/data
SAVE_PATH=${ROOT_PATH}/models
MODEL_PATH=${ROOT_PATH}/UniVL
INIT_MODEL=<from first stage>
python -m torch.distributed.launch --nproc_per_node=8 \
${MODEL_PATH}/main_pretrain.py \
--do_pretrain --num_thread_reader=0 --epochs=50 \
--batch_size=960 --n_pair=3 --n_display=100 \
--bert_model bert-base-uncased --do_lower_case --lr 1e-4 \
--max_words 48 --max_frames 64 --batch_size_val 344 \
--output_dir ${SAVE_PATH}/pre_trained/L48_V6_D3_Phase2 \
--features_path ${DATA_PATH}/features \
--train_csv ${DATA_PATH}/HowTo100M.csv \
--data_path ${DATA_PATH}/caption.pickle \
--visual_num_hidden_layers 6 --decoder_num_hidden_layers 3 \
--gradient_accumulation_steps 60 \
--stage_two --pretrain_with_joint_sim --sampled_use_mil \
--pretrain_enhance_vmodal \
--load_checkpoint --init_model ${INIT_MODEL}
```

# Citation
If you find UniVL useful in your work, you can cite the following paper:
```
@Article{Luo2020UniVL,
  author  = {Huaishao Luo and Lei Ji and Botian Shi and Haoyang Huang and Nan Duan and Tianrui Li and Jason Li and Taroon Bharti and Ming Zhou},
  title   = {UniVL: A Unified Video and Language Pre-Training Model for Multimodal Understanding and Generation},
  journal = {arXiv preprint arXiv:2002.06353},
  year    = {2020},
}
```

# Acknowledgments
Our code is based on [pytorch-transformers v0.4.0](https://github.com/huggingface/transformers/tree/v0.4.0). We thank the authors for their wonderful open-source efforts.

@@ -1,9 +0,0 @@
# pip install prefetch_generator

from torch.utils.data import DataLoader
from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):
    def __iter__(self):
        # transforms the generator into a background-thread generator.
        return BackgroundGenerator(super().__iter__(), max_prefetch=1)

dataloader_howto100m.py
@@ -7,11 +7,9 @@ from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import random

class Youtube_DataLoader(Dataset):
    """
    Youtube dataset loader.
    Note: Use transcript as caption, for mask decoder pretrain task.
@@ -21,16 +19,13 @@ class Youtube_Transcript_DataLoader(Dataset):
        self,
        csv,
        features_path,
        data_dict,
        tokenizer,
        min_time=10.0,
        feature_framerate=1.0,
        max_words=30,
        min_words=0,
        n_pair=-1,
        max_frames=100,
        with_long_context=True,
        use_mil=False,
@@ -43,24 +38,16 @@ class Youtube_Transcript_DataLoader(Dataset):
        Args:
        """
        self.csv = pd.read_csv(csv)
        self.features_path = features_path
        self.data_dict = data_dict
        self.min_time = min_time
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.min_words = min_words
        self.tokenizer = tokenizer
        self.n_pair = n_pair
        self.with_long_context = with_long_context

        self.feature_size = video_dim

        self.only_sim = only_sim
@@ -69,7 +56,7 @@ class Youtube_Transcript_DataLoader(Dataset):

        self.use_mil = use_mil
        self.sampled_use_mil = sampled_use_mil
        if self.sampled_use_mil:  # sample from each video, has a higher priority than use_mil.
            self.use_mil = True

        if self.use_mil:
@@ -82,8 +69,8 @@ class Youtube_Transcript_DataLoader(Dataset):
        self.iter2video_pairslist_dict = {}
        iter_idx_mil_ = 0
        for video_id in video_id_list:
            data_dict = self.data_dict[video_id]
            n_caption = len(data_dict['start'])

            sub_list = []
            if self.n_pair < 0 or self.n_pair == 1:
@@ -114,46 +101,38 @@ class Youtube_Transcript_DataLoader(Dataset):
    def __len__(self):
        return self.iter_num

    def _mask_tokens(self, words):
        token_labels = []
        masked_tokens = words.copy()

        for token_id, token in enumerate(masked_tokens):
            if token_id == 0 or token_id == len(masked_tokens) - 1:
                token_labels.append(-1)
                continue
            prob = random.random()
            # mask each inner token with 15% probability
            if prob < 0.15:
                prob /= 0.15
                # 80%: replace with [MASK]; 10%: random token; rest 10%: keep as-is
                if prob < 0.8:
                    masked_tokens[token_id] = "[MASK]"
                elif prob < 0.9:
                    masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
                try:
                    token_labels.append(self.tokenizer.vocab[token])
                except KeyError:
                    token_labels.append(self.tokenizer.vocab["[UNK]"])
            else:
                token_labels.append(-1)

        return masked_tokens, token_labels

    def _get_text(self, video_id, n_pair_max, sub_ids=None, only_sim=False, enhance_vmodel=False):
        data_dict = self.data_dict[video_id]

        if self.use_mil:
            k = len(sub_ids)
            r_ind = sub_ids
        else:
            n_caption = len(data_dict['start'])
            if n_pair_max == -1:
                k = n_caption
                r_ind = range(n_caption)
@@ -181,16 +160,10 @@ class Youtube_Transcript_DataLoader(Dataset):

        for i in range(k):
            ind = r_ind[i]
            words, start_, end_ = self._get_single_transcript(data_dict, ind, with_long_context=self.with_long_context)
            caption_words = words.copy()
            starts[i], ends[i] = start_, end_

            if enhance_vmodel:
                words = []  # mask all input text

@@ -199,7 +172,7 @@ class Youtube_Transcript_DataLoader(Dataset):
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + ["[SEP]"]

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
@@ -216,17 +189,13 @@ class Youtube_Transcript_DataLoader(Dataset):
            pairs_segment[i] = np.array(segment_ids)

            if only_sim is False:
                # For generate captions
                if len(caption_words) > total_length_with_CLS:
                    caption_words = caption_words[:total_length_with_CLS]
                input_caption_words = ["[CLS]"] + caption_words
                output_caption_words = caption_words + ["[SEP]"]

                masked_tokens, token_labels = self._mask_tokens(words)
                masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
                masked_input_caption_words, input_token_labels = self._mask_tokens(input_caption_words)
                input_caption_words = masked_input_caption_words.copy()
@@ -259,16 +228,16 @@ class Youtube_Transcript_DataLoader(Dataset):
        return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, \
               pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, starts, ends

    def _get_single_transcript(self, data_dict, ind, with_long_context=True):
        start, end = ind, ind
        words = self.tokenizer.tokenize(str(data_dict['text'][ind]))
        diff = data_dict['end'][end] - data_dict['start'][start]
        while with_long_context and (len(words) < self.min_words or diff < self.min_time):
            if start > 0 and end < len(data_dict['end']) - 1:
                next_words = self.tokenizer.tokenize(str(data_dict['text'][end + 1]))
                prev_words = self.tokenizer.tokenize(str(data_dict['text'][start - 1]))
                d1 = data_dict['end'][end + 1] - data_dict['start'][start]
                d2 = data_dict['end'][end] - data_dict['start'][start - 1]
                if (self.min_time > 0 and d2 <= d1) or \
                        (self.min_time == 0 and len(next_words) <= len(prev_words)):
                    start -= 1
@@ -277,17 +246,17 @@ class Youtube_Transcript_DataLoader(Dataset):
                    end += 1
                    words.extend(next_words)
            elif start > 0:
                words = self.tokenizer.tokenize(str(data_dict['text'][start - 1])) + words
                start -= 1
            elif end < len(data_dict['end']) - 1:
                words.extend(self.tokenizer.tokenize(str(data_dict['text'][end + 1])))
                end += 1
            else:
                break
            diff = data_dict['end'][end] - data_dict['start'][start]
        return words, data_dict['start'][start], data_dict['end'][end]

    def _expand_video_slice(self, s, e, si, ei, fps, video_features):
        start = int(s[si] * fps)
        end = int(e[ei] * fps) + 1

@@ -311,12 +280,6 @@ class Youtube_Transcript_DataLoader(Dataset):
            start, end = end, start
        video_slice = video_features[start:end]

        if self.max_frames < video_slice.shape[0]:
            video_slice = video_slice[:self.max_frames]

@@ -326,32 +289,20 @@ class Youtube_Transcript_DataLoader(Dataset):
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)

        max_video_length = [0] * len(s)

        video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
        feature_file = os.path.join(self.features_path, self.csv["feature_file"].values[idx])
        try:
            video_features = np.load(feature_file)

            for i in range(len(s)):
                if len(video_features) < 1:
                    raise ValueError("{} is empty.".format(feature_file))
                video_slice, start, end = self._expand_video_slice(s, e, i, i, self.feature_framerate, video_features)
                slice_shape = video_slice.shape
                max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
                if len(video_slice) < 1:
                    pass
                else:
                    video[i][:slice_shape[0]] = video_slice
        except Exception as e:
@@ -387,7 +338,7 @@ class Youtube_Transcript_DataLoader(Dataset):
        return "%02d:%02d:%02d" % (h, m2, s)

    def __getitem__(self, feature_idx):
        if self.sampled_use_mil:  # sample from each video, has a higher priority than use_mil.
            idx = feature_idx
            video_id = self.csv['video_id'].values[idx]
            sub_list = self.iter2video_pairslist_dict[video_id]

@@ -0,0 +1,352 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import os
from torch.utils.data import Dataset
import numpy as np
import pickle
import pandas as pd
from collections import defaultdict
import json
import random

class MSRVTT_DataLoader(Dataset):
    """MSRVTT dataset loader."""
    def __init__(
            self,
            csv_path,
            features_path,
            tokenizer,
            max_words=30,
            feature_framerate=1.0,
            max_frames=100,
    ):
        self.data = pd.read_csv(csv_path)
        self.feature_dict = pickle.load(open(features_path, 'rb'))
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.tokenizer = tokenizer

        self.feature_size = self.feature_dict[self.data['video_id'].values[0]].shape[-1]

    def __len__(self):
        return len(self.data)

    def _get_text(self, video_id, sentence):
        choice_video_ids = [video_id]
        n_caption = len(choice_video_ids)

        k = n_caption
        pairs_text = np.zeros((k, self.max_words), dtype=np.long)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
        pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
        pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)

        for i, video_id in enumerate(choice_video_ids):
            words = self.tokenizer.tokenize(sentence)

            words = ["[CLS]"] + words
            total_length_with_CLS = self.max_words - 1
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + ["[SEP]"]

            # Mask Language Model <-----
            token_labels = []
            masked_tokens = words.copy()
            for token_id, token in enumerate(masked_tokens):
                if token_id == 0 or token_id == len(masked_tokens) - 1:
                    token_labels.append(-1)
                    continue
                prob = random.random()
                # mask token with 15% probability
                if prob < 0.15:
                    prob /= 0.15
                    # 80% randomly change token to mask token
                    if prob < 0.8:
                        masked_tokens[token_id] = "[MASK]"
                    # 10% randomly change token to random token
                    elif prob < 0.9:
                        masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
                    # -> rest 10% randomly keep current token
                    # append current token to output (we will predict these later)
                    try:
                        token_labels.append(self.tokenizer.vocab[token])
                    except KeyError:
                        # For unknown words (should not occur with BPE vocab)
                        token_labels.append(self.tokenizer.vocab["[UNK]"])
                        # print("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
                else:
                    # no masking token (will be ignored by loss function later)
                    token_labels.append(-1)
            # -----> Mask Language Model

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                masked_token_ids.append(0)
                token_labels.append(-1)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words
            assert len(masked_token_ids) == self.max_words
            assert len(token_labels) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)
            pairs_masked_text[i] = np.array(masked_token_ids)
            pairs_token_labels[i] = np.array(token_labels)

        return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, choice_video_ids

    def _get_video(self, choice_video_ids):
        video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long)
        max_video_length = [0] * len(choice_video_ids)

        video = np.zeros((len(choice_video_ids), self.max_frames, self.feature_size), dtype=np.float)
        for i, video_id in enumerate(choice_video_ids):
            video_slice = self.feature_dict[video_id]

            if self.max_frames < video_slice.shape[0]:
                video_slice = video_slice[:self.max_frames]

            slice_shape = video_slice.shape
            max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
            if len(video_slice) < 1:
                print("video_id: {}".format(video_id))
            else:
                video[i][:slice_shape[0]] = video_slice

        for i, v_length in enumerate(max_video_length):
            video_mask[i][:v_length] = [1] * v_length

        # Mask Frame Model <-----
        video_labels_index = [[] for _ in range(len(choice_video_ids))]
        masked_video = video.copy()
        for i, video_pair_ in enumerate(masked_video):
            for j, _ in enumerate(video_pair_):
                if j < max_video_length[i]:
                    prob = random.random()
                    # mask frame with 15% probability
                    if prob < 0.15:
                        masked_video[i][j] = [0.] * video.shape[-1]
                        video_labels_index[i].append(j)
                    else:
                        video_labels_index[i].append(-1)
                else:
                    video_labels_index[i].append(-1)
        video_labels_index = np.array(video_labels_index, dtype=np.long)
        # -----> Mask Frame Model

        return video, video_mask, masked_video, video_labels_index

    def __getitem__(self, idx):
        video_id = self.data['video_id'].values[idx]
        sentence = self.data['sentence'].values[idx]

        pairs_text, pairs_mask, pairs_segment, \
        pairs_masked_text, pairs_token_labels, choice_video_ids = self._get_text(video_id, sentence)

        video, video_mask, masked_video, video_labels_index = self._get_video(choice_video_ids)

        return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
               pairs_masked_text, pairs_token_labels, masked_video, video_labels_index

class MSRVTT_TrainDataLoader(Dataset):
    """MSRVTT train dataset loader."""
    def __init__(
            self,
            csv_path,
            json_path,
            features_path,
            tokenizer,
            max_words=30,
            feature_framerate=1.0,
            max_frames=100,
            unfold_sentences=False,
    ):
        self.csv = pd.read_csv(csv_path)
        self.data = json.load(open(json_path, 'r'))
        self.feature_dict = pickle.load(open(features_path, 'rb'))
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.tokenizer = tokenizer

        self.feature_size = self.feature_dict[self.csv['video_id'].values[0]].shape[-1]

        self.unfold_sentences = unfold_sentences
        self.sample_len = 0
        if self.unfold_sentences:
            train_video_ids = list(self.csv['video_id'].values)
            self.sentences_dict = {}
            for itm in self.data['sentences']:
                if itm['video_id'] in train_video_ids:
                    self.sentences_dict[len(self.sentences_dict)] = (itm['video_id'], itm['caption'])
            self.sample_len = len(self.sentences_dict)
        else:
            num_sentences = 0
            self.sentences = defaultdict(list)
            s_video_id_set = set()
            for itm in self.data['sentences']:
                self.sentences[itm['video_id']].append(itm['caption'])
                num_sentences += 1
                s_video_id_set.add(itm['video_id'])

            # Use to find the clips in the same video
            self.parent_ids = {}
            self.children_video_ids = defaultdict(list)
            for itm in self.data['videos']:
                vid = itm["video_id"]
                url_posfix = itm["url"].split("?v=")[-1]
                self.parent_ids[vid] = url_posfix
                self.children_video_ids[url_posfix].append(vid)
            self.sample_len = len(self.csv)

    def __len__(self):
        return self.sample_len

    def _get_text(self, video_id, caption=None):
        k = 1
        choice_video_ids = [video_id]
        pairs_text = np.zeros((k, self.max_words), dtype=np.long)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.long)
        pairs_masked_text = np.zeros((k, self.max_words), dtype=np.long)
        pairs_token_labels = np.zeros((k, self.max_words), dtype=np.long)

        for i, video_id in enumerate(choice_video_ids):
            if caption is not None:
                words = self.tokenizer.tokenize(caption)
            else:
                words = self._get_single_text(video_id)

            words = ["[CLS]"] + words
            total_length_with_CLS = self.max_words - 1
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + ["[SEP]"]

            # Mask Language Model <-----
            token_labels = []
            masked_tokens = words.copy()
            for token_id, token in enumerate(masked_tokens):
                if token_id == 0 or token_id == len(masked_tokens) - 1:
                    token_labels.append(-1)
                    continue
                prob = random.random()
                # mask token with 15% probability
                if prob < 0.15:
                    prob /= 0.15
                    # 80% randomly change token to mask token
                    if prob < 0.8:
                        masked_tokens[token_id] = "[MASK]"
                    # 10% randomly change token to random token
                    elif prob < 0.9:
                        masked_tokens[token_id] = random.choice(list(self.tokenizer.vocab.items()))[0]
                    # -> rest 10% randomly keep current token
                    # append current token to output (we will predict these later)
                    try:
                        token_labels.append(self.tokenizer.vocab[token])
                    except KeyError:
                        # For unknown words (should not occur with BPE vocab)
                        token_labels.append(self.tokenizer.vocab["[UNK]"])
                        # print("Cannot find token '{}' in vocab. Using [UNK] instead".format(token))
                else:
                    # no masking token (will be ignored by loss function later)
                    token_labels.append(-1)
            # -----> Mask Language Model

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            masked_token_ids = self.tokenizer.convert_tokens_to_ids(masked_tokens)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
                masked_token_ids.append(0)
                token_labels.append(-1)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words
            assert len(masked_token_ids) == self.max_words
            assert len(token_labels) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)
            pairs_masked_text[i] = np.array(masked_token_ids)
            pairs_token_labels[i] = np.array(token_labels)

        return pairs_text, pairs_mask, pairs_segment, pairs_masked_text, pairs_token_labels, choice_video_ids

    def _get_single_text(self, video_id):
        rind = random.randint(0, len(self.sentences[video_id]) - 1)
        caption = self.sentences[video_id][rind]
        words = self.tokenizer.tokenize(caption)
        return words

    def _get_video(self, choice_video_ids):
        video_mask = np.zeros((len(choice_video_ids), self.max_frames), dtype=np.long)
        max_video_length = [0] * len(choice_video_ids)

        video = np.zeros((len(choice_video_ids), self.max_frames, self.feature_size), dtype=np.float)
        for i, video_id in enumerate(choice_video_ids):
            video_slice = self.feature_dict[video_id]

            if self.max_frames < video_slice.shape[0]:
                video_slice = video_slice[:self.max_frames]

            slice_shape = video_slice.shape
            max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
            if len(video_slice) < 1:
                print("video_id: {}".format(video_id))
            else:
                video[i][:slice_shape[0]] = video_slice

        for i, v_length in enumerate(max_video_length):
            video_mask[i][:v_length] = [1] * v_length

        # Mask Frame Model <-----
        video_labels_index = [[] for _ in range(len(choice_video_ids))]
        masked_video = video.copy()
        for i, video_pair_ in enumerate(masked_video):
            for j, _ in enumerate(video_pair_):
                if j < max_video_length[i]:
                    prob = random.random()
                    # mask frame with 15% probability
                    if prob < 0.15:
                        masked_video[i][j] = [0.] * video.shape[-1]
                        video_labels_index[i].append(j)
                    else:
                        video_labels_index[i].append(-1)
                else:
                    video_labels_index[i].append(-1)
        video_labels_index = np.array(video_labels_index, dtype=np.long)
        # -----> Mask Frame Model

        return video, video_mask, masked_video, video_labels_index

    def __getitem__(self, idx):
        if self.unfold_sentences:
            video_id, caption = self.sentences_dict[idx]
        else:
            video_id, caption = self.csv['video_id'].values[idx], None
        pairs_text, pairs_mask, pairs_segment, \
        pairs_masked_text, pairs_token_labels, choice_video_ids = self._get_text(video_id, caption)

        video, video_mask, masked_video, video_labels_index = self._get_video(choice_video_ids)

        return pairs_text, pairs_mask, pairs_segment, video, video_mask, \
               pairs_masked_text, pairs_token_labels, masked_video, video_labels_index
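
A hedged usage sketch for the new MSRVTT loader, wiring in the file names from the README above (the module name `dataloader_msrvtt` is an assumption based on the classes it defines; `unfold_sentences` is presumably what `--expand_msrvtt_sentences` toggles):
```
from modules.tokenization import BertTokenizer
from dataloader_msrvtt import MSRVTT_TrainDataLoader  # module name assumed

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
dataset = MSRVTT_TrainDataLoader(
    csv_path="data/msrvtt/MSRVTT_train.9k.csv",
    json_path="data/msrvtt/MSRVTT_data.json",
    features_path="data/msrvtt/msrvtt_videos_features.pickle",
    tokenizer=tokenizer,
    max_words=48,
    max_frames=48,
    unfold_sentences=True,
)
# One sample: text tensors plus masked-LM and masked-frame targets.
(pairs_text, pairs_mask, pairs_segment, video, video_mask,
 pairs_masked_text, pairs_token_labels, masked_video, video_labels_index) = dataset[0]
```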

@@ -7,52 +7,35 @@ from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import pickle
import random

class Youcook_Caption_DataLoader(Dataset):
    """Youcook dataset loader."""

    def __init__(
            self,
            csv,
            data_path,
            features_path,
            tokenizer,
            min_time=10.0,
            feature_framerate=1.0,
            max_words=30,
            min_words=0,
            n_pair=-1,  # -1 for test
            max_frames=100,
    ):
        """
        Args:
        """
        self.csv = pd.read_csv(csv)
        self.min_time = min_time
        self.data_dict = pickle.load(open(data_path, 'rb'))
        self.feature_dict = pickle.load(open(features_path, 'rb'))
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.min_words = min_words
        self.tokenizer = tokenizer
        self.n_pair = n_pair

        self.feature_size = self.feature_dict[self.csv["feature_file"].values[0]].shape[-1]

        # Get iterator video ids
        video_id_list = [itm for itm in self.csv['video_id'].values]
@@ -61,8 +44,8 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
        self.iter2video_pairs_dict = {}
        iter_idx_ = 0
        for video_id in video_id_list:
            data_dict = self.data_dict[video_id]
            n_caption = len(data_dict['start'])
            for sub_id in range(n_caption):
                self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
                iter_idx_ += 1
@@ -71,7 +54,7 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
        return len(self.iter2video_pairs_dict)

    def _get_text(self, video_id, sub_id):
        data_dict = self.data_dict[video_id]
        k = 1
        r_ind = [sub_id]

@@ -89,10 +72,10 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):

        for i in range(k):
            ind = r_ind[i]
            start_, end_ = data_dict['start'][ind], data_dict['end'][ind]
            starts[i], ends[i] = start_, end_
            total_length_with_CLS = self.max_words - 1
            words = self.tokenizer.tokenize(data_dict['transcript'][ind])

            words = ["[CLS]"] + words
            if len(words) > total_length_with_CLS:
@@ -156,7 +139,7 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
            pairs_token_labels[i] = np.array(token_labels)

            # For generate captions
            caption_words = self.tokenizer.tokenize(data_dict['text'][ind])
            if len(caption_words) > total_length_with_CLS:
                caption_words = caption_words[:total_length_with_CLS]
            input_caption_words = ["[CLS]"] + caption_words
@@ -184,16 +167,12 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
    def _get_video(self, idx, s, e):
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
        max_video_length = [0] * len(s)

        video_features = self.feature_dict[self.csv["feature_file"].values[idx]]
        video = np.zeros((len(s), self.max_frames, self.feature_size), dtype=np.float)
        for i in range(len(s)):
            start = int(s[i] * self.feature_framerate)
            end = int(e[i] * self.feature_framerate) + 1
            video_slice = video_features[start:end]

            if self.max_frames < video_slice.shape[0]:
@@ -202,7 +181,7 @@ class Youcook_Transcript_NoPair_DataLoader(Dataset):
            slice_shape = video_slice.shape
            max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
            if len(video_slice) < 1:
                print("video_id: {}, start: {}, end: {}".format(self.csv["video_id"].values[idx], start, end))
                # pass
            else:
                video[i][:slice_shape[0]] = video_slice

@@ -7,46 +7,31 @@ from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import pickle
import random

class Youcook_DataLoader(Dataset):
    """Youcook dataset loader."""
    def __init__(
            self,
            csv,
            data_path,
            features_path,
            tokenizer,
            min_time=10.0,
            feature_framerate=1.0,
            max_words=30,
            min_words=0,
            n_pair=-1,  # -1 for test
            max_frames=100,
    ):
        """
        Args:
        """
        self.csv = pd.read_csv(csv)
        self.min_time = min_time
        self.data_dict = pickle.load(open(data_path, 'rb'))
        self.feature_dict = pickle.load(open(features_path, 'rb'))
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.min_words = min_words
        self.tokenizer = tokenizer
        self.n_pair = n_pair

        # Get iterator video ids
        video_id_list = [itm for itm in self.csv['video_id'].values]
@@ -55,8 +40,8 @@ class Youcook_NoPair_DataLoader(Dataset):
        self.iter2video_pairs_dict = {}
        iter_idx_ = 0
        for video_id in video_id_list:
            data_dict = self.data_dict[video_id]
            n_caption = len(data_dict['start'])
            for sub_id in range(n_caption):
                self.iter2video_pairs_dict[iter_idx_] = (video_id, sub_id)
                iter_idx_ += 1
@@ -65,9 +50,8 @@ class Youcook_NoPair_DataLoader(Dataset):
        return len(self.iter2video_pairs_dict)

    def _get_text(self, video_id, sub_id):
        data_dict = self.data_dict[video_id]
        k, r_ind = 1, [sub_id]

        starts = np.zeros(k)
        ends = np.zeros(k)
@@ -79,9 +63,8 @@ class Youcook_NoPair_DataLoader(Dataset):

        for i in range(k):
            ind = r_ind[i]
            words = self.tokenizer.tokenize(data_dict['text'][ind])
            start_, end_ = data_dict['start'][ind], data_dict['end'][ind]
            starts[i], ends[i] = start_, end_

            words = ["[CLS]"] + words
@@ -151,16 +134,12 @@ class Youcook_NoPair_DataLoader(Dataset):
    def _get_video(self, idx, s, e):
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
        max_video_length = [0] * len(s)

        video_features = self.feature_dict[self.csv["feature_file"].values[idx]]
        video = np.zeros((len(s), self.max_frames, video_features.shape[-1]), dtype=np.float)
        for i in range(len(s)):
            start = int(s[i] * self.feature_framerate)
            end = int(e[i] * self.feature_framerate) + 1
            video_slice = video_features[start:end]

            if self.max_frames < video_slice.shape[0]:
@@ -169,7 +148,7 @@ class Youcook_NoPair_DataLoader(Dataset):
            slice_shape = video_slice.shape
            max_video_length[i] = max_video_length[i] if max_video_length[i] > slice_shape[0] else slice_shape[0]
            if len(video_slice) < 1:
                print("video_id: {}, start: {}, end: {}".format(self.csv["video_id"].values[idx], start, end))
            else:
                video[i][:slice_shape[0]] = video_slice

@ -0,0 +1,410 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import unicode_literals
|
||||
from __future__ import print_function
|
||||
|
||||
import torch
|
||||
from torch.utils.data import (SequentialSampler)
|
||||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import pickle
|
||||
import logging
|
||||
import time
|
||||
import argparse
|
||||
from modules.tokenization import BertTokenizer
|
||||
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from modules.modeling import UniVL
|
||||
from modules.optimization import BertAdam
|
||||
from dataloader_howto100m import Youtube_DataLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from util import get_logger
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
|
||||
global logger
|
||||
|
||||
def get_args(description='UniVL on Pretrain'):
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
|
||||
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
|
||||
|
||||
parser.add_argument('--train_csv', type=str, default='data/HowTo100M_v1.csv', help='train csv')
|
||||
parser.add_argument('--features_path', type=str, default='feature', help='feature path for 2D features')
|
||||
parser.add_argument('--data_path', type=str, default='data/data.pickle', help='data pickle file path')
|
||||
|
||||
parser.add_argument('--num_thread_reader', type=int, default=1, help='')
|
||||
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
|
||||
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
|
||||
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
|
||||
parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
|
||||
parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
|
||||
parser.add_argument('--n_display', type=int, default=100, help='Information display frequence')
|
||||
parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
|
||||
parser.add_argument('--seed', type=int, default=42, help='random seed')
|
||||
parser.add_argument('--max_words', type=int, default=20, help='')
|
||||
parser.add_argument('--max_frames', type=int, default=100, help='')
|
||||
parser.add_argument('--min_words', type=int, default=0, help='')
|
||||
parser.add_argument('--feature_framerate', type=int, default=1, help='')
|
||||
parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
|
||||
parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
|
||||
parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
|
||||
parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
|
||||
parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')
|
||||
|
||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True,
|
||||
help="Bert pre-trained model")
|
||||
parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module")
|
||||
parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
|
||||
parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module")
|
||||
parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
|
||||
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.")
|
||||
parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
|
||||
|
||||
parser.add_argument("--cache_dir", default="", type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
||||
|
||||
parser.add_argument('--fp16', action='store_true',
|
||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html")
|
||||
|
||||
parser.add_argument("--world_size", default=0, type=int, help="distribted training")
|
||||
parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
|
||||
parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
|
||||
parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
|
||||
parser.add_argument('--sampled_use_mil', action='store_true', help="Whether use MIL, has a high priority than use_mil.")
|
||||
|
||||
parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
|
||||
parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Layer NO. of visual.")
|
||||
parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Layer NO. of cross.")
|
||||
parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Layer NO. of decoder.")
|
||||
|
||||
parser.add_argument('--stage_two', action='store_true', help="Whether training with decoder.")
|
||||
parser.add_argument('--pretrain_enhance_vmodal', action='store_true', help="Enhance visual and other modalities when pretraining.")
|
||||
|
||||
parser.add_argument("--load_checkpoint", action="store_true")
|
||||
parser.add_argument("--checkpoint_model", default="pytorch_model.bin.checkpoint", type=str, required=False,
|
||||
help="Save the last model as a checkpoint.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.sampled_use_mil: # sample from each video, has a higher priority than use_mil.
|
||||
args.use_mil = True
|
||||
|
||||
# Check paramenters
|
||||
if args.gradient_accumulation_steps < 1:
|
||||
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
|
||||
args.gradient_accumulation_steps))
|
||||
if not args.do_pretrain:
|
||||
raise ValueError("`do_pretrain` must be True.")
|
||||
|
||||
args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)
|
||||
|
||||
args.checkpoint_model = '{}_{}_{}_{}.checkpoint'.format(args.checkpoint_model, args.bert_model, args.max_words, args.max_frames)
|
||||
|
||||
return args
|
||||
|
||||
def set_seed_logger(args):
|
||||
global logger
|
||||
random.seed(args.seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
args.world_size = world_size
|
||||
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
logger = get_logger(os.path.join(args.output_dir, "log.txt"))
|
||||
|
||||
if args.local_rank == 0:
|
||||
logger.info("Effective parameters:")
|
||||
for key in sorted(args.__dict__):
|
||||
logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
|
||||
|
||||
return args
|
||||
|
||||
def init_device(args, local_rank):
|
||||
global logger
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)
|
||||
|
||||
n_gpu = torch.cuda.device_count()
|
||||
logger.info("device: {} n_gpu: {}".format(device, n_gpu))
|
||||
args.n_gpu = n_gpu
|
||||
|
||||
if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0:
|
||||
raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
|
||||
args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))
|
||||
|
||||
return device, n_gpu
|
||||
|
||||
def init_model(args, device, n_gpu, local_rank):
|
||||
|
||||
if args.init_model:
|
||||
model_state_dict = torch.load(args.init_model, map_location='cpu')
|
||||
else:
|
||||
model_state_dict = None
|
||||
|
||||
# Prepare model
|
||||
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
|
||||
model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
|
||||
cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)
|
||||
|
||||
model.to(device)
|
||||
|
||||
return model
|
||||
|
||||

def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):

    if hasattr(model, 'module'):
        model = model.module

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Note: the names below are inverted relative to their contents; `no_decay_param_tp`
    # actually holds the parameters that *do* receive weight decay (0.01), while
    # `decay_param_tp` holds the bias/LayerNorm parameters that are exempt (0.0).
    no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
    decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]

    no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n]
    no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n]

    decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n]
    decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n]

    optimizer_grouped_parameters = [
        {'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01},
        {'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
        {'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
    ]

    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
                         schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
                         max_grad_norm=1.0)

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                      output_device=local_rank, find_unused_parameters=True)

    return optimizer, scheduler, model
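
# Added note: the grouping above trains the BERT text branch at a scaled rate
# (args.lr * coef_lr; main() passes coef_lr=0.1 by default and 1.0 when --init_model
# supplies pretrained weights), while the visual/cross/decoder modules use the full
# args.lr. BertAdam applies each group's 'lr' override on top of the shared
# warmup_linear schedule.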

def dataloader_pretrain(args, tokenizer, only_sim=False):
    if args.local_rank == 0:
        logger.info('Loading captions: {}'.format(args.data_path))
    data_dict = pickle.load(open(args.data_path, 'rb'))
    if args.local_rank == 0:
        logger.info('Done, data_dict length: {}'.format(len(data_dict)))

    dataset = Youtube_DataLoader(
        csv=args.train_csv,
        features_path=args.features_path,
        data_dict=data_dict,
        min_time=args.min_time,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=args.n_pair,
        max_frames=args.max_frames,
        use_mil=args.use_mil,
        only_sim=only_sim,
        sampled_use_mil=args.sampled_use_mil,
        pretrain_enhance_vmodal=args.pretrain_enhance_vmodal,
        video_dim=args.video_dim,
    )

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size // args.n_gpu,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )

    return dataloader, len(dataset), sampler
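
# Added note: with a DistributedSampler each rank reads a disjoint shard, so the
# per-process batch is args.batch_size // args.n_gpu and the global batch stays at
# args.batch_size. shuffle=(sampler is None) evaluates to False here; per-epoch
# shuffling is driven by sampler.set_epoch(epoch) in main() instead.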

def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    # Note: the recursive calls below fall back to the default ttype; this is fine
    # here because the function is only ever invoked with the default.
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict
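
# Added note: casting saved optimizer state to torch.FloatTensor via this helper
# makes the checkpoint independent of the dtype it was trained with; the standard
# torch Optimizer.load_state_dict casts state back to each parameter's dtype and
# device on resume.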

def save_model(epoch, args, model, local_rank, type_name="", global_step=-1, optimizer=None):
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(
        args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name == "" else type_name + ".", epoch))
    torch.save(model_to_save.state_dict(), output_model_file)
    logger.info("Model saved to %s", output_model_file)

    if global_step != -1 and optimizer is not None:
        state_dict = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': model_to_save.state_dict(),
            'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        }
        checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
        torch.save(state_dict, checkpoint_model_file)
        logger.info("Checkpoint is saved. Use `load_checkpoint` to recover it.")

    return output_model_file

def load_model(epoch, args, n_gpu, device, model, global_step=0, model_file=None):
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))

    last_optim_state = None
    checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
    if epoch == -1 and args.load_checkpoint and os.path.exists(checkpoint_model_file):
        checkpoint_state = torch.load(checkpoint_model_file, map_location='cpu')
        epoch = checkpoint_state['epoch']
        global_step = checkpoint_state['global_step']
        model_state_dict = checkpoint_state['model_state_dict']
        last_optim_state = checkpoint_state['last_optimizer_state']
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
                                      cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)
        if args.local_rank == 0:
            logger.info("Checkpoint loaded from %s", checkpoint_model_file)
    elif os.path.exists(model_file):
        model_state_dict = torch.load(model_file, map_location='cpu')
        if args.local_rank == 0:
            logger.info("Model loaded from %s", model_file)

        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
                                      cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)

    return epoch, global_step, last_optim_state, model

def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
    global logger
    torch.cuda.empty_cache()
    model.train()
    log_step = args.n_display
    start_time = time.time()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask, \
        pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
        pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch

        loss = model(input_ids, segment_ids, input_mask, video, video_mask,
                     pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels,
                     masked_video=masked_video, video_labels_index=video_labels_index,
                     input_caption_ids=pairs_input_caption_ids, decoder_mask=pairs_decoder_mask,
                     output_caption_ids=pairs_output_caption_ids)

        if n_gpu > 1:
            loss = loss.mean()
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            if scheduler is not None:
                scheduler.step()

            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            if global_step % log_step == 0 and local_rank == 0:
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
                            len(train_dataloader), "-".join([str('%.6f' % itm) for itm in sorted(list(set(optimizer.get_lr())))]),
                            float(loss) * args.gradient_accumulation_steps,
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                start_time = time.time()

    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step
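
# Added note: `loss` was scaled down by gradient_accumulation_steps before backward(),
# so the log line multiplies it back to report the unscaled per-batch loss.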

def main():
    global logger
    args = get_args()
    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    model = init_model(args, device, n_gpu, args.local_rank)
    only_sim = model.module._stage_one if hasattr(model, 'module') else model._stage_one

    train_dataloader, train_length, sampler = dataloader_pretrain(args, tokenizer, only_sim=only_sim)
    num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                    / args.gradient_accumulation_steps) * args.epochs

    global_step = 0
    epoch = -1
    last_optim_state = None
    if args.load_checkpoint:
        epoch, global_step, last_optim_state, model = load_model(epoch, args, n_gpu, device, model, global_step=global_step)
        epoch += 1
        if args.local_rank == 0:
            logger.warning("Will continue to epoch: {}".format(epoch))
    epoch = 0 if epoch < 0 else epoch

    coef_lr = args.coef_lr
    if args.init_model:
        coef_lr = 1.0

    optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)
    if last_optim_state is not None:
        optimizer.load_state_dict(last_optim_state)

    if args.local_rank == 0:
        logger.info("***** Running pretraining *****")
        logger.info(" Num examples = %d", train_length)
        logger.info(" Batch size = %d", args.batch_size)
        logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

    iter_ls_ = [itm for itm in range(args.epochs) if itm >= epoch]
    for epoch in iter_ls_:
        sampler.set_epoch(epoch)

        tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                           scheduler, global_step, local_rank=args.local_rank)

        if args.local_rank == 0:
            logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
            save_model(epoch, args, model, args.local_rank, type_name="pretrain", global_step=global_step, optimizer=optimizer)


if __name__ == "__main__":
    main()
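
# Usage sketch (added): launching this pretraining script follows the same
# torch.distributed.launch convention as the finetuning tasks. The script and data
# file names below are assumptions, not taken from this commit:
#
#   python -m torch.distributed.launch --nproc_per_node=4 main_pretrain.py \
#     --do_pretrain --train_csv <HowTo100M csv> --data_path <caption pickle> \
#     --features_path <features pickle> --output_dir ckpts/ckpt_pretrain \
#     --bert_model bert-base-uncased --do_lower_case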

@ -4,62 +4,40 @@ from __future__ import unicode_literals
from __future__ import print_function

import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)

from data_prefetch_unitl import DataLoaderX as DataLoader  # Enhanced Loader

from torch.utils.data import (SequentialSampler)
import numpy as np
import random
import os
from collections import OrderedDict
from youcook_transcript_nopair_dataloader import Youcook_Transcript_NoPair_DataLoader
from youtube_transcript_dataloader import Youtube_Transcript_DataLoader
from rouge import Rouge
from nlgeval import compute_metrics, NLGEval
from nlgeval import NLGEval
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from pytorch_pretrained_bert.beam import Beam

from modules.tokenization import BertTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import UniVL
from modules.optimization import BertAdam
from modules.beam import Beam
from torch.utils.data import DataLoader
from dataloader_youcook_caption import Youcook_Caption_DataLoader
from util import get_logger
torch.distributed.init_process_group(backend="nccl")

global amp_pck_loaded_
try:
    from apex import amp
    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    amp.register_half_function(torch, 'einsum')
    amp_pck_loaded_ = True
except ImportError:
    amp_pck_loaded_ = False

def get_logger(filename=None):
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        handler = logging.FileHandler(filename)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logging.getLogger().addHandler(handler)
    return logger

global logger

def get_args(description='Youtube-Text-Video'):
def get_args(description='UniVL on Caption Task'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--train_csv', type=str, default='data/HowTo100M_v1.csv', help='train csv')
    parser.add_argument('--features_path_2D', type=str, default='feature_2d', help='feature path for 2D features')
    parser.add_argument('--features_path_3D', type=str, default='feature_3d', help='feature path for 3D features')
    parser.add_argument('--caption_path', type=str, default='data/caption.pickle', help='caption pickle file path')
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument('--train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
    parser.add_argument('--val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
    parser.add_argument('--data_path', type=str, default='data/youcookii_caption_transcript.pickle',
                        help='caption and transcription pickle file path')
    parser.add_argument('--features_path', type=str, default='data/youcookii_videos_feature.pickle',
                        help='feature path for 2D features')

    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')

@ -72,7 +50,6 @@ def get_args(description='Youtube-Text-Video'):
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=20, help='')
    parser.add_argument('--max_frames', type=int, default=100, help='')
    parser.add_argument('--min_words', type=int, default=0, help='')
    parser.add_argument('--feature_framerate', type=int, default=1, help='')
    parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')

@ -80,34 +57,16 @@ def get_args(description='Youtube-Text-Video'):
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pairs to output from data loader')

    parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
    parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
    parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
    parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
    parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption_transcript.pickle', help='youcookii caption and transcription pickle file path')
    parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii feature path for 2D features')
    parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii feature path for 3D features')

    parser.add_argument('--pad_token', type=str, default='[PAD]', help='')

    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
                        help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
    parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
                        help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
    parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
                        help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
    parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True, help="Bert pre-trained model")
    parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.");
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")

@ -121,52 +80,32 @@ def get_args(description='Youtube-Text-Video'):
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--task_type", default="caption", type=str, help="Point the task `retrieval` or `caption` to finetune.")
    parser.add_argument("--task_type", default="caption", type=str, help="Point the task `caption` to finetune.")
    parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")

    parser.add_argument("--world_size", default=0, type=int, help="distributed training")
    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
    parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use MIL; has a higher priority than use_mil.")

    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Layer NO. of visual.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Layer NO. of visual.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Layer NO. of cross.")
    parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Layer NO. of decoder.")

    parser.add_argument('--cross_model', action='store_true', help="Whether training with decoder.")
    parser.add_argument('--pretrain_with_joint_sim', action='store_true', help="Whether to use the joint embedding when pretraining.")
    parser.add_argument('--pretrain_enhance_vmodal', action='store_true', help="Enhance visual and other modalities when pretraining.")
    parser.add_argument('--without_sim_in_decoder', action='store_true', help="Whether to align in the decoder when training.")
    parser.add_argument('--pretrain_without_decoder', action='store_true', help="Whether to ignore the decoder when pretraining.")

    parser.add_argument("--load_checkpoint", action="store_true")
    parser.add_argument("--checkpoint_model", default="pytorch_model.bin.checkpoint", type=str, required=False,
                        help="Save the last model as a checkpoint.")
    parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Layer NO. of decoder.")

    parser.add_argument('--stage_two', action='store_true', help="Whether training with decoder.")
    args = parser.parse_args()

    if args.sampled_use_mil:  # sample from each video, has a higher priority than use_mil.
        args.use_mil = True
    if args.do_pretrain is False:
        args.pretrain_without_decoder = False

    global amp_pck_loaded_
    if args.fp16:
        if not amp_pck_loaded_:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    # Check parameters
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_pretrain and not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_pretrain` or `do_train` or `do_eval` must be True.")
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)

    args.checkpoint_model = '{}_{}_{}_{}.checkpoint'.format(args.checkpoint_model, args.bert_model, args.max_words, args.max_frames)

    return args

def set_seed_logger(args):

@ -180,24 +119,22 @@ def set_seed_logger(args):
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # torch.multiprocessing.set_sharing_strategy('file_system')  # RuntimeError: received 0 items of ancdata.

    world_size = torch.distributed.get_world_size()
    local_rank = torch.distributed.get_rank()
    torch.cuda.set_device(local_rank)
    args.local_rank = local_rank
    torch.cuda.set_device(args.local_rank)
    args.world_size = world_size

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    if local_rank == 0:
    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(args.__dict__):
            logger.info(" <<< {}: {}".format(key, args.__dict__[key]))

    return args, world_size, local_rank
    return args

def init_device(args, local_rank):
    global logger

@ -223,7 +160,7 @@ def init_model(args, device, n_gpu, local_rank):

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
    model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
                                  cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    model.to(device)

@ -259,33 +196,19 @@ def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, loc
                         schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
                         max_grad_norm=1.0)

    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                          output_device=local_rank, find_unused_parameters=True)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                      output_device=local_rank, find_unused_parameters=True)

    return optimizer, scheduler, model

def dataloader_youcook_train(args, tokenizer):
    logger.info('Loading captions: {}'.format(args.youcook_caption_path))
    caption = pickle.load(open(args.youcook_caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    youcook_dataset = Youcook_Transcript_NoPair_DataLoader(
        csv=args.youcook_train_csv,
        features_path=args.youcook_features_path_2D,
        features_path_3D=args.youcook_features_path_3D,
        caption=caption,
        min_time=args.min_time,
    youcook_dataset = Youcook_Caption_DataLoader(
        csv=args.train_csv,
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=-1,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

@ -303,22 +226,13 @@ def dataloader_youcook_train(args, tokenizer):
    return dataloader, len(youcook_dataset), train_sampler

def dataloader_youcook_test(args, tokenizer):
    logger.info('Loading captions: {}'.format(args.youcook_caption_path))
    caption = pickle.load(open(args.youcook_caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    youcook_testset = Youcook_Transcript_NoPair_DataLoader(
        csv=args.youcook_val_csv,
        features_path=args.youcook_features_path_2D,
        features_path_3D=args.youcook_features_path_3D,
        caption=caption,
        min_time=args.min_time,
    youcook_testset = Youcook_Caption_DataLoader(
        csv=args.val_csv,
        data_path=args.data_path,
        features_path=args.features_path,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=-1,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

@ -331,47 +245,10 @@ def dataloader_youcook_test(args, tokenizer):
        pin_memory=False,
    )

    logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
    if args.local_rank == 0:
        logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset)))
    return dataloader_youcook, len(youcook_testset)

def dataloader_pretrain(args, tokenizer, only_sim=False):
    logger.info('Loading captions: {}'.format(args.caption_path))
    caption = pickle.load(open(args.caption_path, 'rb'))
    logger.info('Done, caption length: {}'.format(len(caption)))

    dataset = Youtube_Transcript_DataLoader(
        csv=args.train_csv,
        features_path=args.features_path_2D,
        features_path_3D=args.features_path_3D,
        caption=caption,
        min_time=args.min_time,
        max_words=args.max_words,
        min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
        n_pair=args.n_pair,
        pad_token=args.pad_token,
        max_frames=args.max_frames,
        use_mil=args.use_mil,
        only_sim=only_sim,
        sampled_use_mil=args.sampled_use_mil,
        pretrain_enhance_vmodal=args.pretrain_enhance_vmodal,
        video_dim=args.video_dim,
    )

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size // args.n_gpu,
        num_workers=args.num_thread_reader,
        pin_memory=False,
        shuffle=(sampler is None),
        sampler=sampler,
        drop_last=True,
    )

    return dataloader, len(dataset), sampler

def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()

@ -385,63 +262,31 @@ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    else:
        return state_dict

def save_model(epoch, args, model, local_rank, type_name="", global_step=-1, optimizer=None):
def save_model(epoch, args, model, type_name=""):
    # Only save the model itself
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(
        args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name == "" else type_name + ".", epoch))
    torch.save(model_to_save.state_dict(), output_model_file)
    logger.info("Model saved to %s", output_model_file)

    if global_step != -1 and optimizer is not None:
        amp_state_dict = {}
        if args.fp16:
            amp_state_dict = amp.state_dict()
        state_dict = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': model_to_save.state_dict(),
            'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
            'amp_state_dict': amp_state_dict,
        }
        checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
        torch.save(state_dict, checkpoint_model_file)
        logger.info("Checkpoint is saved. Use `load_checkpoint` to recover it.")

    return output_model_file

def load_model(epoch, args, n_gpu, device, model, global_step=0, model_file=None):
def load_model(epoch, args, n_gpu, device, model_file=None):
    if model_file is None or len(model_file) == 0:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))

    last_optim_state = None
    amp_state_dict = None
    checkpoint_model_file = os.path.join(args.output_dir, args.checkpoint_model)
    if epoch == -1 and args.load_checkpoint and os.path.exists(checkpoint_model_file):
        checkpoint_state = torch.load(checkpoint_model_file, map_location='cpu')
        epoch = checkpoint_state['epoch']
        global_step = checkpoint_state['global_step']
        model_state_dict = checkpoint_state['model_state_dict']
        last_optim_state = checkpoint_state['last_optimizer_state']
        if args.fp16 and 'amp_state_dict' in checkpoint_state:
            amp_state_dict = checkpoint_state['amp_state_dict']
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
                                       cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)
        logger.info("Checkpoint loaded from %s", checkpoint_model_file)
    elif os.path.exists(model_file):
    if os.path.exists(model_file):
        model_state_dict = torch.load(model_file, map_location='cpu')
        logger.info("Model loaded from %s", model_file)
        if args.local_rank == 0:
            logger.info("Model loaded from %s", model_file)
        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
        model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
        model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
                                      cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)

        return epoch, global_step, last_optim_state, amp_state_dict, model
    else:
        model = None
    return model

def train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer, scheduler,
                global_step, nlgEvalObj=None, local_rank=0):

@ -473,19 +318,12 @@ def train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu,
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        loss.backward()

        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step()  # Update learning rate schedule

@ -602,12 +440,12 @@ def collect_hypothesis_and_scores(inst_dec_beams, n_best):
    return all_hyp, all_scores
# >----------------------------------------

def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=None, nlgEvalObj=None, test_set=None):
def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=None, test_set=None):

    if hasattr(model, 'module'):
        model = model.module.to(device)

    if model._choice_sim:
    if model._stage_one:
        return 0.

    all_result_lists = []

@ -628,12 +466,12 @@ def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=
            n_inst, len_s, d_h = sequence_output.size()
            _, len_v, v_h = visual_output.size()

            decoder = model.decoder_caption  # This is a decoder function
            decoder = model.decoder_caption

            # Note: shaped first; the decoder then needs the parameter shaped=True
            input_ids = input_ids.view(-1, input_ids.shape[-1])  # [batch_size*n_pair, self.max_words]
            input_mask = input_mask.view(-1, input_mask.shape[-1])  # [batch_size*n_pair, self.max_words]
            video_mask = video_mask.view(-1, video_mask.shape[-1])  # [batch_size*n_pair, self.max_frames]
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            input_mask = input_mask.view(-1, input_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

            sequence_output_rpt = sequence_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
            visual_output_rpt = visual_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_v, v_h)
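
# Added note: the repeat(1, n_bm, 1).view(n_inst * n_bm, ...) pattern tiles each
# sample's encoder output once per beam, so with n_inst=2 clips and beam width
# n_bm=5 the decoder advances 10 hypothesis rows in parallel, each attending to
# its own copy of the text and visual context.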

@ -696,7 +534,7 @@ def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=
        writer.write("{}\t{}\t{}\n".format("video_id", "start_time", "caption"))
        for idx, pre_txt in enumerate(all_result_lists):
            video_id, sub_id = test_set.iter2video_pairs_dict[idx]
            start_time = test_set.caption[video_id]['start'][sub_id]
            start_time = test_set.data_dict[video_id]['start'][sub_id]
            writer.write("{}\t{}\t{}\n".format(video_id, start_time, pre_txt))
    logger.info("File of complete results is saved in {}".format(hyp_path))

@ -711,17 +549,7 @@ def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=
        for ground_txt in all_caption_lists:
            writer.write(ground_txt + "\n")

    # Filter out hyps of 0 length
    hyps_and_refs = zip(all_result_lists, all_caption_lists)
    hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0
                     and len([" ".join(sub_s.split()) for sub_s in _[0].split(".") if len(sub_s) > 0]) > 0]
    all_result_lists, all_caption_lists = zip(*hyps_and_refs)

    # Evaluate
    metrics_dict = rougeObj.get_scores(hyps=all_result_lists, refs=all_caption_lists, avg=True, ignore_empty=True)
    logger.info(">>> rouge_1f: {:.4f}, rouge_2f: {:.4f}, rouge_lf: {:.4f}".
                format(metrics_dict["rouge-1"]["f"], metrics_dict["rouge-2"]["f"], metrics_dict["rouge-l"]["f"]))

    metrics_nlg = nlgEvalObj.compute_metrics(ref_list=[all_caption_lists], hyp_list=all_result_lists)
    logger.info(">>> BLEU_1: {:.4f}, BLEU_2: {:.4f}, BLEU_3: {:.4f}, BLEU_4: {:.4f}".
                format(metrics_nlg["Bleu_1"], metrics_nlg["Bleu_2"], metrics_nlg["Bleu_3"], metrics_nlg["Bleu_4"]))
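
# Added note: nlg-eval's compute_metrics expects ref_list as a list of reference
# lists (one inner list per reference set, aligned index-by-index with hyp_list);
# wrapping all_caption_lists in [...] supplies exactly one ground-truth caption
# per generated hypothesis.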

@ -736,83 +564,34 @@ DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader

def main():
    global logger
    args = get_args()
    args, world_size, local_rank = set_seed_logger(args)
    device, n_gpu = init_device(args, local_rank)
    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
    model = init_model(args, device, n_gpu, args.local_rank)

    assert args.task_type == "caption"
    nlgEvalObj = NLGEval(no_overlap=False, no_skipthoughts=True, no_glove=True, metrics_to_omit=None)

    assert args.datatype in DATALOADER_DICT
    test_dataloader, test_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer)
    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info(" Num examples = %d", test_length)
        logger.info(" Batch size = %d", args.batch_size_val)
        logger.info(" Num steps = %d", len(test_dataloader))

    if args.do_train:
    args.task_type = "caption"
    model = init_model(args, device, n_gpu, local_rank)
    only_sim = model.module._choice_sim if hasattr(model, 'module') else model._choice_sim

    if args.task_type == "caption":
        rougeObj = Rouge()
        nlgEvalObj = NLGEval(no_overlap=False, no_skipthoughts=True, no_glove=True, metrics_to_omit=None)

    datatype = "youcook"
    assert datatype in DATALOADER_DICT
    if args.do_pretrain is False:
        test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
        if local_rank == 0:
            logger.info("***** Running test *****")
            logger.info(" Num examples = %d", test_length)
            logger.info(" Batch size = %d", args.batch_size_val)
            logger.info(" Num steps = %d", len(test_dataloader))

    if args.do_pretrain:
        train_dataloader, train_length, sampler = dataloader_pretrain(args, tokenizer, only_sim=only_sim)
        num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        / args.gradient_accumulation_steps) * args.epochs

        global_step = 0
        epoch = -1
        last_optim_state = None
        amp_state_dict = None
        if args.load_checkpoint:
            epoch, global_step, last_optim_state, amp_state_dict, model = load_model(epoch, args, n_gpu, device, model, global_step=global_step)
            epoch += 1
            if local_rank == 0:
                logger.warning("Will continue to epoch: {}".format(epoch))
        epoch = 0 if epoch < 0 else epoch

        coef_lr = args.coef_lr
        if args.init_model:
            coef_lr = 1.0

        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)
        if last_optim_state is not None:
            optimizer.load_state_dict(last_optim_state)
        if amp_state_dict is not None:
            amp.load_state_dict(amp_state_dict)

        if local_rank == 0:
            logger.info("***** Running pretraining *****")
            logger.info(" Num examples = %d", train_length)
            logger.info(" Batch size = %d", args.batch_size)
            logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        iter_ls_ = [itm for itm in range(args.epochs) if itm >= epoch]
        for epoch in iter_ls_:
            sampler.set_epoch(epoch)

            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
                                               scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)

            if local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
                save_model(epoch, args, model, local_rank, type_name="pretrain", global_step=global_step, optimizer=optimizer)
    elif args.do_train:

        train_dataloader, train_length, train_sampler = DATALOADER_DICT[datatype]["train"](args, tokenizer)
        train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        / args.gradient_accumulation_steps) * args.epochs

        coef_lr = args.coef_lr
        if args.init_model:
            coef_lr = 1.0
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=coef_lr)
        optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)

        if local_rank == 0:
        if args.local_rank == 0:
            logger.info("***** Running training *****")
            logger.info(" Num examples = %d", train_length)
            logger.info(" Batch size = %d", args.batch_size)

@ -825,13 +604,13 @@ def main():
            train_sampler.set_epoch(epoch)

            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, tokenizer, device, n_gpu, optimizer,
                                               scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=local_rank)
                                               scheduler, global_step, nlgEvalObj=nlgEvalObj, local_rank=args.local_rank)

            if local_rank == 0:
            if args.local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
                output_model_file = save_model(epoch, args, model, local_rank, type_name="")
                output_model_file = save_model(epoch, args, model, type_name="")
                if epoch > 0:
                    Bleu_4 = eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
                    Bleu_4 = eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
                    if best_score <= Bleu_4:
                        best_score = Bleu_4
                        best_output_model_file = output_model_file

@ -839,12 +618,12 @@ def main():
                else:
                    logger.warning("Skip the evaluation after {}-th epoch.".format(epoch + 1))

        if local_rank == 0:
            _, _, _, _, model = load_model(-1, args, n_gpu, device, model, model_file=best_output_model_file)
            eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
        if args.local_rank == 0:
            model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
            eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)
    elif args.do_eval:
        if local_rank == 0:
            eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, rougeObj=rougeObj, nlgEvalObj=nlgEvalObj)
        if args.local_rank == 0:
            eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu, nlgEvalObj=nlgEvalObj)

if __name__ == "__main__":
    main()
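
# Usage sketch (added): finetuning this caption script is assumed to mirror the
# retrieval launch; the script name and file paths below are illustrative
# assumptions, not taken from this commit:
#
#   python -m torch.distributed.launch --nproc_per_node=4 main_task_caption.py \
#     --do_train --datatype youcook --train_csv <train csv> --val_csv <val csv> \
#     --data_path <caption+transcript pickle> --features_path <features pickle> \
#     --init_model weight/univl.pretrained.bin \
#     --output_dir ckpts/ckpt_youcook_caption --bert_model bert-base-uncased --do_lower_case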

@ -4,43 +4,39 @@ from __future__ import unicode_literals
from __future__ import print_function

import torch
from torch.utils.data import (RandomSampler, SequentialSampler, TensorDataset)

# from torch.utils.data import DataLoader
from data_prefetch_unitl import DataLoaderX as DataLoader  # Enhanced Loader

from torch.utils.data import (SequentialSampler)
import numpy as np
import random
import os
from youcook_nopair_dataloader import Youcook_NoPair_DataLoader
from metrics import compute_metrics, print_computed_metrics
from metrics import compute_metrics
import pickle
import logging
import time
import argparse
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import VLBert
from pytorch_pretrained_bert.optimization_bert import BertAdam
from util import parallel_apply

def get_logger(filename=None):
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        handler = logging.FileHandler(filename)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logging.getLogger().addHandler(handler)
    return logger
from modules.tokenization import BertTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import UniVL
from modules.optimization import BertAdam
from torch.utils.data import DataLoader
from util import parallel_apply, get_logger
from dataloader_youcook_retrieval import Youcook_DataLoader
from dataloader_msrvtt_retrieval import MSRVTT_DataLoader
from dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader
torch.distributed.init_process_group(backend="nccl")

global logger

def get_args(description='Youtube-Text-Video'):
def get_args(description='UniVL on Retrieval Task'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument('--train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
    parser.add_argument('--val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
    parser.add_argument('--data_path', type=str, default='data/youcookii_caption.pickle', help='data pickle file path')
    parser.add_argument('--features_path', type=str, default='data/youcookii_videos_feature.pickle', help='feature path')

    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')

@ -52,42 +48,23 @@ def get_args(description='Youtube-Text-Video'):
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=20, help='')
    parser.add_argument('--max_frames', type=int, default=100, help='')
    parser.add_argument('--min_words', type=int, default=0, help='')
    parser.add_argument('--feature_framerate', type=int, default=1, help='')
    parser.add_argument('--min_time', type=float, default=5.0, help='Gather small clips')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
    parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pairs to output from data loader')

    parser.add_argument('--youcook', type=int, default=0, help='Train on YouCook2 data')
    parser.add_argument('--eval_youcook', type=int, default=0, help='Evaluate on YouCook2 data')
    parser.add_argument('--youcook_train_csv', type=str, default='data/youcookii_singlef_train.csv', help='')
    parser.add_argument('--youcook_val_csv', type=str, default='data/youcookii_singlef_val.csv', help='')
    parser.add_argument('--youcook_caption_path', type=str, default='data/youcookii_caption.pickle', help='youcookii caption pickle file path')
    parser.add_argument('--youcook_features_path_2D', type=str, default='data/youcookii_videos_feature2d', help='youcookii 2D feature path')
    parser.add_argument('--youcook_features_path_3D', type=str, default='data/youcookii_videos_feature3d', help='youcookii 3D feature path')

    parser.add_argument('--pad_token', type=str, default='[PAD]', help='')

    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--visual_bert_model", default="visual-bert-base", type=str, required=False,
                        help="VisualBert pre-trained model selected in the list: visual-bert-base, visual-bert-large")
    parser.add_argument("--cross_bert_model", default="cross-bert-base", type=str, required=False,
                        help="CrossBert pre-trained model selected in the list: cross-bert-base, cross-bert-large")
    parser.add_argument("--decoder_bert_model", default="decoder-bert-base", type=str, required=False,
                        help="DecoderBert pre-trained model selected in the list: decoder-bert-base, decoder-bert-large")
    parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True,
                        help="Bert pre-trained model")
    parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.");
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")

@ -103,28 +80,29 @@ def get_args(description='Youtube-Text-Video'):

    parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
    parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.")

    parser.add_argument("--world_size", default=0, type=int, help="distributed training")
    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.')
    parser.add_argument('--use_mil', action='store_true', help="Whether to use MIL as in Miech et al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether to use MIL; has a higher priority than use_mil.")

    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=1, help="Layer NO. of visual.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Layer NO. of visual.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Layer NO. of cross.")
    parser.add_argument('--decoder_num_hidden_layers', type=int, default=1, help="Layer NO. of decoder.")
    parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Layer NO. of decoder.")

    parser.add_argument('--train_sim_after_cross', action='store_true', help="Test retrieval after cross encoder.")
    parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="")

    args = parser.parse_args()

    if args.sampled_use_mil:  # sample from each video, has a higher priority than use_mil.
        args.use_mil = True

    # Check parameters
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_pretrain and not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_pretrain` or `do_train` or `do_eval` must be True.")
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)

@ -142,17 +120,27 @@ def set_seed_logger(args):
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(args.local_rank)
    args.world_size = world_size

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
        os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))
    logger.info("Effective parameters:")
    for key in sorted(args.__dict__):
        logger.info(" <<< {}: {}".format(key, args.__dict__[key]))

def init_device(args):
    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(args.__dict__):
            logger.info(" <<< {}: {}".format(key, args.__dict__[key]))

    return args

def init_device(args, local_rank):
    global logger
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)

    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    args.n_gpu = n_gpu

@ -163,7 +151,7 @@ def init_device(args):

    return device, n_gpu

def init_model(args, device, n_gpu):
def init_model(args, device, n_gpu, local_rank):

    if args.init_model:
        model_state_dict = torch.load(args.init_model, map_location='cpu')

@ -172,24 +160,14 @@ def init_model(args, device, n_gpu):

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
    model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
                                  cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    model.to(device)

    # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
    # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
    # remove the need for this code, but it is still valid.
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, 'einsum')
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    return model

def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=1.0):
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):

    if hasattr(model, 'module'):
        model = model.module
@ -212,72 +190,49 @@ def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coe
|
|||
{'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr},
|
||||
{'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0}
|
||||
]
|
||||
|
||||
scheduler = None
|
||||
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
|
||||
schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01,
|
||||
max_grad_norm=1.0)
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
|
||||
output_device=local_rank, find_unused_parameters=True)
|
||||
|
||||
return optimizer, scheduler, model
|
||||
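`prep_optimizer` hands `BertAdam` several parameter groups so that BERT-initialized weights train at `args.lr * coef_lr` while newly added modules use the full rate. A hedged sketch of that grouping idea, using `torch.optim.AdamW` as a stand-in for the repo's `BertAdam` (module names here are assumptions, not this commit's code):

```
# Sketch: lr-scaled parameter groups; 'bert'/'cross' are illustrative names.
import torch
from torch import nn

model = nn.ModuleDict({'bert': nn.Linear(8, 8), 'cross': nn.Linear(8, 8)})
lr, coef_lr = 3e-5, 0.1

bert_param   = [p for n, p in model.named_parameters() if n.startswith('bert')]
nobert_param = [p for n, p in model.named_parameters() if not n.startswith('bert')]

groups = [
    {'params': bert_param, 'lr': lr * coef_lr},  # pretrained weights: damped lr
    {'params': nobert_param},                    # new modules: full lr
]
optimizer = torch.optim.AdamW(groups, lr=lr)
```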
def dataloader_youcook_train(args, tokenizer):
-   logger.info('Loading captions: {}'.format(args.youcook_caption_path))
-   caption = pickle.load(open(args.youcook_caption_path, 'rb'))
-   logger.info('Done, caption length: {}'.format(len(caption)))
-
-   args.n_pair = 1
-   youcook_dataset = Youcook_NoPair_DataLoader(
-       csv=args.youcook_train_csv,
-       features_path=args.youcook_features_path_2D,
-       features_path_3D=args.youcook_features_path_3D,
-       caption=caption,
-       min_time=args.min_time,
+   youcook_dataset = Youcook_DataLoader(
+       csv=args.train_csv,
+       data_path=args.data_path,
+       features_path=args.features_path,
        max_words=args.max_words,
-       min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
-       n_pair=args.n_pair,
-       pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

+   train_sampler = torch.utils.data.distributed.DistributedSampler(youcook_dataset)
    dataloader = DataLoader(
        youcook_dataset,
-       batch_size=args.batch_size,
+       batch_size=args.batch_size // args.n_gpu,
        num_workers=args.num_thread_reader,
-       shuffle=True,
+       pin_memory=False,
+       shuffle=(train_sampler is None),
+       sampler=train_sampler,
        drop_last=True,
    )

-   return dataloader, len(youcook_dataset)
+   return dataloader, len(youcook_dataset), train_sampler
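The loader now pairs a `DistributedSampler` with a per-GPU batch size (`batch_size // n_gpu`) and returns the sampler so the epoch loop can reshuffle shards. A small usage sketch (toy dataset; `DistributedSampler` requires a process group to be initialized first):

```
# Sketch: per-epoch reshuffling with a DistributedSampler (toy dataset).
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))     # stand-in for the video dataset
sampler = DistributedSampler(dataset)          # needs init_process_group() first
loader = DataLoader(dataset, batch_size=8, sampler=sampler, drop_last=True)

for epoch in range(3):
    sampler.set_epoch(epoch)                   # reseeds the shard permutation
    for (batch,) in loader:
        pass                                   # training step goes here
```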
def dataloader_youcook_test(args, tokenizer):
-   logger.info('Loading captions: {}'.format(args.youcook_caption_path))
-   caption = pickle.load(open(args.youcook_caption_path, 'rb'))
-   logger.info('Done, caption length: {}'.format(len(caption)))
-
-   youcook_testset = Youcook_NoPair_DataLoader(
-       csv=args.youcook_val_csv,
-       features_path=args.youcook_features_path_2D,
-       features_path_3D=args.youcook_features_path_3D,
-       caption=caption,
-       min_time=args.min_time,
+   youcook_testset = Youcook_DataLoader(
+       csv=args.val_csv,
+       data_path=args.data_path,
+       features_path=args.features_path,
        max_words=args.max_words,
-       min_words=args.min_words,
        feature_framerate=args.feature_framerate,
        tokenizer=tokenizer,
-       n_pair=-1,
-       pad_token=args.pad_token,
        max_frames=args.max_frames,
    )

@@ -293,6 +248,49 @@ def dataloader_youcook_test(args, tokenizer):

    return dataloader_youcook, len(youcook_testset)

+def dataloader_msrvtt_train(args, tokenizer):
+   msrvtt_dataset = MSRVTT_TrainDataLoader(
+       csv_path=args.train_csv,
+       json_path=args.data_path,
+       features_path=args.features_path,
+       max_words=args.max_words,
+       feature_framerate=args.feature_framerate,
+       tokenizer=tokenizer,
+       max_frames=args.max_frames,
+       unfold_sentences=args.expand_msrvtt_sentences,
+   )
+
+   train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset)
+   dataloader = DataLoader(
+       msrvtt_dataset,
+       batch_size=args.batch_size // args.n_gpu,
+       num_workers=args.num_thread_reader,
+       pin_memory=False,
+       shuffle=(train_sampler is None),
+       sampler=train_sampler,
+       drop_last=True,
+   )
+
+   return dataloader, len(msrvtt_dataset), train_sampler
+
+def dataloader_msrvtt_test(args, tokenizer):
+   msrvtt_testset = MSRVTT_DataLoader(
+       csv_path=args.val_csv,
+       features_path=args.features_path,
+       max_words=args.max_words,
+       feature_framerate=args.feature_framerate,
+       tokenizer=tokenizer,
+       max_frames=args.max_frames,
+   )
+   dataloader_msrvtt = DataLoader(
+       msrvtt_testset,
+       batch_size=args.batch_size_val,
+       num_workers=args.num_thread_reader,
+       shuffle=False,
+       drop_last=False,
+   )
+   return dataloader_msrvtt, len(msrvtt_testset)

def save_model(epoch, args, model, type_name=""):
    # Only save the model itself
    model_to_save = model.module if hasattr(model, 'module') else model
@@ -307,10 +305,11 @@ def load_model(epoch, args, n_gpu, device, model_file=None):
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))
    if os.path.exists(model_file):
        model_state_dict = torch.load(model_file, map_location='cpu')
-       logger.info("Model loaded from %s", model_file)
+       if args.local_rank == 0:
+           logger.info("Model loaded from %s", model_file)
        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
-       model = VLBert.from_pretrained(args.bert_model, args.visual_bert_model, args.cross_bert_model, args.decoder_bert_model,
+       model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model,
            cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

        model.to(device)
@@ -318,24 +317,18 @@ def load_model(epoch, args, n_gpu, device, model_file=None):
        model = None
    return model

-def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step):
+def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
    global logger
    torch.cuda.empty_cache()
    model.train()
-   log_step = 100
+   log_step = args.n_display
    start_time = time.time()
    total_loss = 0

-   if args.fp16:
-       try:
-           from apex import amp
-       except ImportError:
-           raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-
    for step, batch in enumerate(train_dataloader):
-       if n_gpu == 1:
-           # multi-gpu does scattering it-self
-           batch = tuple(t.to(device) for t in batch)
+       batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask, \
        pairs_masked_text, pairs_token_labels, masked_video, video_labels_index = batch

@@ -348,19 +341,12 @@ def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

-       if args.fp16:
-           with amp.scale_loss(loss, optimizer) as scaled_loss:
-               scaled_loss.backward()
-       else:
-           loss.backward()
+       loss.backward()

        total_loss += float(loss)
        if (step + 1) % args.gradient_accumulation_steps == 0:

-           if args.fp16:
-               torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
-           else:
-               torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+           torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step()  # Update learning rate schedule

@@ -369,10 +355,11 @@ def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
            optimizer.zero_grad()

            global_step += 1
-           if global_step % log_step == 0:
+           if global_step % log_step == 0 and local_rank == 0:
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
-                           len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]), loss,
+                           len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
+                           float(loss) * args.gradient_accumulation_steps,
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                start_time = time.time()
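With apex removed, the loop above is the plain fp32 accumulate-then-step pattern. A compact, self-contained sketch of that pattern (generic model and data, not the repo's):

```
# Sketch: gradient accumulation with clipping (toy model and data).
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for step in range(16):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y) / accumulation_steps
    loss.backward()                                  # gradients accumulate
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```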
@@ -418,7 +405,6 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu):

        print("{}/{}\r".format(bid, len(test_dataloader)), end="")

    start_time = time.time()
    if n_gpu > 1:
        device_ids = list(range(n_gpu))
        batch_list_t_splits = []

@@ -455,7 +441,6 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu):
            sim_matrix += parallel_outputs[idx]
    else:
        sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list)
    logger.info("%f" % (time.time() - start_time))

    sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
    metrics = compute_metrics(sim_matrix)

@@ -468,62 +453,65 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu):

DATALOADER_DICT = {}
DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader_youcook_test}
DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test}
def main():
    global logger
    args = get_args()
-   set_seed_logger(args)
-   device, n_gpu = init_device(args)
+   args = set_seed_logger(args)
+   device, n_gpu = init_device(args, args.local_rank)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

-   if args.do_train:
-       args.task_type = "retrieval"
-       model = init_model(args, device, n_gpu)
-
-       datatype = args.datatype
-       assert datatype in DATALOADER_DICT
-       test_dataloader, test_length = DATALOADER_DICT[datatype]["val"](args, tokenizer)
-       logger.info("***** Running test *****")
-       logger.info("  Num examples = %d", test_length)
-       logger.info("  Batch size = %d", args.batch_size_val)
-       logger.info("  Num steps = %d", len(test_dataloader))
+   assert args.task_type == "retrieval"
+   model = init_model(args, device, n_gpu, args.local_rank)
+
+   assert args.datatype in DATALOADER_DICT
+   test_dataloader, test_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer)
+   if args.local_rank == 0:
+       logger.info("***** Running test *****")
+       logger.info("  Num examples = %d", test_length)
+       logger.info("  Batch size = %d", args.batch_size_val)
+       logger.info("  Num steps = %d", len(test_dataloader))

+   if args.do_pretrain:
+       raise NotImplementedError("Deleted~ no use")
+   elif args.do_train:

-       train_dataloader, train_length = DATALOADER_DICT[datatype]["train"](args, tokenizer)
+       train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
        num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1)
                                        / args.gradient_accumulation_steps) * args.epochs

-       coef_lr = 0.1
+       coef_lr = args.coef_lr
        if args.init_model:
            coef_lr = 1.0
-       optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, coef_lr=coef_lr)
+       optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr)

-       logger.info("***** Running training *****")
-       logger.info("  Num examples = %d", train_length)
-       logger.info("  Batch size = %d", args.batch_size)
-       logger.info("  Num steps = %d", num_train_optimization_steps)
+       if args.local_rank == 0:
+           logger.info("***** Running training *****")
+           logger.info("  Num examples = %d", train_length)
+           logger.info("  Batch size = %d", args.batch_size)
+           logger.info("  Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        best_score = 0.00001
        best_output_model_file = None
        global_step = 0
        for epoch in range(args.epochs):
+           train_sampler.set_epoch(epoch)
            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
-                                              scheduler, global_step)
-           logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
-           output_model_file = save_model(epoch, args, model, type_name="")
-
-           R1 = eval_epoch(args, model, test_dataloader, device, n_gpu)
-           if best_score <= R1:
-               best_score = R1
-               best_output_model_file = output_model_file
-           logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))
-       model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
-       eval_epoch(args, model, test_dataloader, device, n_gpu)
+                                              scheduler, global_step, local_rank=args.local_rank)
+           if args.local_rank == 0:
+               logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)
+               output_model_file = save_model(epoch, args, model, type_name="")
+
+               R1 = eval_epoch(args, model, test_dataloader, device, n_gpu)
+               if best_score <= R1:
+                   best_score = R1
+                   best_output_model_file = output_model_file
+               logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))
+       if args.local_rank == 0:
+           model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
+           eval_epoch(args, model, test_dataloader, device, n_gpu)
    elif args.do_eval:
-       eval_epoch(args, model, test_dataloader, device, n_gpu)
+       if args.local_rank == 0:
+           eval_epoch(args, model, test_dataloader, device, n_gpu)

if __name__ == "__main__":
    main()
@@ -19,7 +19,6 @@ def compute_metrics(x):
    metrics['MR'] = np.median(ind) + 1
    return metrics

def print_computed_metrics(metrics):
    r1 = metrics['R1']
    r5 = metrics['R5']
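`compute_metrics` turns the text-video similarity matrix produced in `eval_epoch` into the retrieval numbers (R@1/R@5/R@10, median rank). A hedged sketch of the standard computation, assuming ground-truth pairs sit on the diagonal (the usual convention; names here are illustrative):

```
# Sketch: R@K / Median R from a square similarity matrix.
import numpy as np

def retrieval_metrics(sim):                    # sim: [n_text, n_video]
    order = np.argsort(-sim, axis=1)           # best-scoring videos first
    rank = np.where(order == np.arange(len(sim))[:, None])[1]  # rank of true video
    return {
        'R1':  float(np.mean(rank == 0)) * 100,
        'R5':  float(np.mean(rank < 5)) * 100,
        'R10': float(np.mean(rank < 10)) * 100,
        'MR':  float(np.median(rank)) + 1,     # 1-based median rank
    }
```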
@@ -1,12 +1,11 @@
-""" Manage beam search info structure.
-
-Heavily borrowed from OpenNMT-py.
-For code in OpenNMT-py, please check the following link:
-https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
-"""
+"""
+Manage beam search info structure.
+Heavily borrowed from OpenNMT-py.
+For code in OpenNMT-py, please check the following link (maybe in oldest version):
+https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
+"""

import torch
import numpy as np

class Constants():
    def __init__(self):

@@ -14,7 +13,7 @@ class Constants():
        self.UNK = 1
        self.BOS = 2
        self.EOS = 3
-       self.PAD_WORD = '[MASK]'
+       self.PAD_WORD = '[PAD]'
        self.UNK_WORD = '[UNK]'
        self.BOS_WORD = '[CLS]'
        self.EOS_WORD = '[SEP]'

@@ -76,7 +75,7 @@ class Beam():
        self.scores = best_scores
        # bestScoresId is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from
-       prev_k = best_scores_id / num_words
+       prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
        # End condition is when top-of-beam is EOS.
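The `/` to `//` change matters under Python 3, where true division yields floats and breaks tensor indexing. A small sketch of the flat-index decomposition the beam uses:

```
# Sketch: recovering (beam, word) from a flattened top-k index.
import torch

num_words = 5
best_scores_id = torch.tensor([7, 12])        # flat ids over a (beam x word) grid

prev_k = best_scores_id // num_words          # beam index: tensor([1, 2])
next_y = best_scores_id - prev_k * num_words  # word index inside the beam: tensor([2, 2])
# best_scores_id / num_words would produce floats in Python 3 and break indexing.
```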
@@ -19,14 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
-import copy
-import json
-import math
import logging
-import tarfile
-import tempfile
-import shutil
import numpy as np

import torch

@@ -34,38 +27,22 @@ from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

-from pytorch_pretrained_bert.bert_module import PreTrainedBertModel, BertModel, BertLayerNorm, BertOnlyMLMHead
-from pytorch_pretrained_bert.visual_module import PreTrainedVisualBertModel, VisualBertModel, VisualBertLayerNorm, VisualBertOnlyMLMHead
-from pytorch_pretrained_bert.cross_module import PreTrainedCrossBertModel, CrossBertModel, CrossBertLayerNorm
-from pytorch_pretrained_bert.decoder_module import PreTrainedDecoderBertModel, DecoderBertModel, DecoderBertLayerNorm
+from modules.until_module import PreTrainedModel, LayerNorm, CrossEn, MILNCELoss, MaxMarginRankingLoss
+from modules.module_bert import BertModel, BertConfig, BertOnlyMLMHead
+from modules.module_visual import VisualModel, VisualConfig, VisualOnlyMLMHead
+from modules.module_cross import CrossModel, CrossConfig
+from modules.module_decoder import DecoderModel, DecoderConfig

logger = logging.getLogger(__name__)
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-except ImportError:
-    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
-    class LayerNorm(nn.Module):
-        def __init__(self, hidden_size, eps=1e-12):
-            """Construct a layernorm module in the TF style (epsilon inside the square root).
-            """
-            super(LayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.ones(hidden_size))
-            self.bias = nn.Parameter(torch.zeros(hidden_size))
-            self.variance_epsilon = eps
-
-        def forward(self, x):
-            u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-            return self.weight * x + self.bias

-class PreTrainedModel(nn.Module):
+class UniVLPreTrainedModel(PreTrainedModel, nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs):
-       super(PreTrainedModel, self).__init__()
+       # utilize bert config as base config
+       super(UniVLPreTrainedModel, self).__init__(bert_config)
        self.bert_config = bert_config
        self.visual_config = visual_config
        self.cross_config = cross_config
@@ -76,88 +53,8 @@ class PreTrainedModel(nn.Module):
        self.cross = None
        self.decoder = None

-   def init_bert_weights(self, module):
-       """ Initialize the weights.
-       """
-       if isinstance(module, (nn.Linear, nn.Embedding)):
-           # Slightly different from the TF version which uses truncated_normal for initialization
-           # cf https://github.com/pytorch/pytorch/pull/5617
-           module.weight.data.normal_(mean=0.0, std=self.bert_config.initializer_range)
-       elif isinstance(module, BertLayerNorm) \
-               or isinstance(module, LayerNorm) \
-               or isinstance(module, VisualBertLayerNorm) \
-               or isinstance(module, CrossBertLayerNorm) \
-               or isinstance(module, DecoderBertLayerNorm):
-           if 'beta' in dir(module) and 'gamma' in dir(module):
-               module.beta.data.zero_()
-               module.gamma.data.fill_(1.0)
-           else:
-               module.bias.data.zero_()
-               module.weight.data.fill_(1.0)
-       if isinstance(module, nn.Linear) and module.bias is not None:
-           module.bias.data.zero_()
-
-   def resize_token_embeddings(self, new_num_tokens=None):
-       raise ValueError("`TODO: Just bert has this function, call it at here.`")
-
-   @classmethod
-   def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
-       old_keys = []
-       new_keys = []
-       for key in state_dict.keys():
-           new_key = None
-           if 'gamma' in key:
-               new_key = key.replace('gamma', 'weight')
-           if 'beta' in key:
-               new_key = key.replace('beta', 'bias')
-           if new_key:
-               old_keys.append(key)
-               new_keys.append(new_key)
-       for old_key, new_key in zip(old_keys, new_keys):
-           state_dict[new_key] = state_dict.pop(old_key)
-
-       if prefix is not None:
-           old_keys = []
-           new_keys = []
-           for key in state_dict.keys():
-               old_keys.append(key)
-               new_keys.append(prefix + key)
-           for old_key, new_key in zip(old_keys, new_keys):
-               state_dict[new_key] = state_dict.pop(old_key)
-
-       missing_keys = []
-       unexpected_keys = []
-       error_msgs = []
-       # copy state_dict so _load_from_state_dict can modify it
-       metadata = getattr(state_dict, '_metadata', None)
-       state_dict = state_dict.copy()
-       if metadata is not None:
-           state_dict._metadata = metadata
-
-       def load(module, prefix=''):
-           local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-           module._load_from_state_dict(
-               state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
-           for name, child in module._modules.items():
-               if child is not None:
-                   load(child, prefix + name + '.')
-
-       load(model, prefix='')
-
-       if prefix is None and (task_config is None or task_config.local_rank == 0):
-           logger.info("-"*20)
-           if len(missing_keys) > 0:
-               logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
-           if len(unexpected_keys) > 0:
-               logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
-           if len(error_msgs) > 0:
-               logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
-
-       return model
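`init_preweight` (now hosted in the shared base class imported from `modules.until_module`, per the new imports above) renames TF-style `gamma`/`beta` checkpoint keys to PyTorch's `weight`/`bias` before loading. A toy sketch of that key migration:

```
# Toy sketch: gamma/beta -> weight/bias key migration (placeholder values).
state_dict = {
    'bert.LayerNorm.gamma': 'tensor-g',
    'bert.LayerNorm.beta': 'tensor-b',
    'bert.dense.weight': 'tensor-w',
}
renamed = {}
for key, value in state_dict.items():
    new_key = key.replace('gamma', 'weight').replace('beta', 'bias')
    renamed[new_key] = value
# renamed now uses PyTorch naming: 'bert.LayerNorm.weight', 'bert.LayerNorm.bias', ...
```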
    @classmethod
-   def from_pretrained(cls, pretrained_model_name, pretrained_visual_model_name, pretrained_cross_model_name,
-                       pretrained_decoder_model_name,
+   def from_pretrained(cls, pretrained_bert_name, visual_model_name, cross_model_name, decoder_model_name,
                        state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):

        task_config = None

@@ -168,94 +65,21 @@ class PreTrainedModel(nn.Module):
        elif task_config.local_rank == -1:
            task_config.local_rank = 0

-       bert_config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=task_config)
-       # visual_state_dict is no use here
-       visual_config, _ = PreTrainedVisualBertModel.get_config(pretrained_visual_model_name, cache_dir, state_dict=None, task_config=task_config)
-       # visual_state_dict is no use here
-       cross_config, _ = PreTrainedCrossBertModel.get_config(pretrained_cross_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
-       decoder_config, _ = PreTrainedDecoderBertModel.get_config(pretrained_decoder_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
+       # Instantiate model.
+       bert_config, state_dict = BertConfig.get_config(pretrained_bert_name, cache_dir, type_vocab_size, state_dict, task_config=task_config)
+       visual_config, _ = VisualConfig.get_config(visual_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
+       cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)
+       decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None, task_config=task_config)

        model = cls(bert_config, visual_config, cross_config, decoder_config, *inputs, **kwargs)

        assert model.bert is not None
        assert model.visual is not None
        # assert model.cross is not None
        # assert model.decoder is not None

        if state_dict is not None:
            model = cls.init_preweight(model, state_dict, task_config=task_config)

        return model

class CrossEn(nn.Module):
    def __init__(self,):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
class MILNCELoss(nn.Module):
    def __init__(self, batch_size=1, n_pair=1,):
        super(MILNCELoss, self).__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair

    def forward(self, sim_matrix):
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)  # video * text

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12

        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=torch.uint8)).mean()
        return sim_loss
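MILNCELoss treats the `n_pair` clips sampled from the same video as one bag of positives; the `np.kron` call builds the block-diagonal mask that marks those bags. A small sketch of the mask with hypothetical sizes:

```
# Sketch: positive-bag mask for batch_size=2 videos, n_pair=2 clips each.
import numpy as np

mm_mask = np.kron(np.eye(2), np.ones((2, 2)))
# [[1. 1. 0. 0.]
#  [1. 1. 0. 0.]
#  [0. 0. 1. 1.]
#  [0. 0. 1. 1.]]   -> rows/cols from the same video form one positive bag
```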
class MaxMarginRankingLoss(nn.Module):
    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        if n_pair > 1 and batch_size > 1:
            alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
            mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
            self.mm_mask = mm_mask.float()

    def forward(self, x):
        d = torch.diag(x)
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
                     F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()
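The forward pass penalizes any off-diagonal similarity that comes within `margin` of its row's or column's positive. A tiny worked example with made-up scores:

```
# Worked example: margin=1.0, two text-video pairs (made-up similarities).
import torch
import torch.nn.functional as F

x = torch.tensor([[0.9, 0.5],
                  [0.2, 0.7]])              # sim(text_i, video_j); diagonal = positives
d = torch.diag(x)                           # tensor([0.9, 0.7])
row_term = F.relu(1.0 + x - d.view(-1, 1))  # hinge vs. each row's positive
col_term = F.relu(1.0 + x - d.view(1, -1))  # hinge vs. each column's positive
loss = (row_term + col_term).mean()         # each diagonal entry contributes the
                                            # constant margin from both terms
```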
class NormalizeVideo(nn.Module):
    def __init__(self, task_config):
        super(NormalizeVideo, self).__init__()

@@ -279,9 +103,12 @@ def update_attr(target_name, target_config, target_attr_name, source_config, sou
                target_attr_name, getattr(target_config, target_attr_name)))
    return target_config

+def check_attr(target_name, task_config):
+    return hasattr(task_config, target_name) and task_config.__dict__[target_name]

-class VLBert(PreTrainedModel):
+class UniVL(UniVLPreTrainedModel):
    def __init__(self, bert_config, visual_config, cross_config, decoder_config, task_config):
-       super(VLBert, self).__init__(bert_config, visual_config, cross_config, decoder_config)
+       super(UniVL, self).__init__(bert_config, visual_config, cross_config, decoder_config)
        self.task_config = task_config
        self.ignore_video_index = -1
@@ -290,57 +117,54 @@ class VLBert(PreTrainedModel):
        assert self.task_config.max_frames <= visual_config.max_position_embeddings
        assert self.task_config.max_words + self.task_config.max_frames <= cross_config.max_position_embeddings

-       self._choice_sim = True
-       self._cross_model = False
+       self._stage_one = True
+       self._stage_two = False

-       if hasattr(self.task_config, 'cross_model') and self.task_config.cross_model:
-           self._choice_sim = False
-           self._cross_model = self.task_config.cross_model
+       if check_attr('stage_two', self.task_config):
+           self._stage_one = False
+           self._stage_two = self.task_config.stage_two
+       show_log(task_config, "Stage-One:{}, Stage-Two:{}".format(self._stage_one, self._stage_two))

        self.train_sim_after_cross = False
-       if self._choice_sim and hasattr(self.task_config, 'train_sim_after_cross') and self.task_config.train_sim_after_cross:
+       if self._stage_one and check_attr('train_sim_after_cross', self.task_config):
            self.train_sim_after_cross = True
            show_log(task_config, "Test retrieval after cross encoder.")

-       self.with_sim_in_decoder = True
-       if hasattr(self.task_config, 'without_sim_in_decoder') and self.task_config.without_sim_in_decoder:
-           self.with_sim_in_decoder = False
-           show_log(task_config, "Set no sim after cross.")
-
-       self.pretrain_without_decoder = False
-       if hasattr(self.task_config, 'pretrain_without_decoder') and self.task_config.pretrain_without_decoder:
-           self.pretrain_without_decoder = True
-           show_log(task_config, "Set pretrain without decoder.")
-
-       show_log(task_config, "sim:{}, cross_model:{}".format(self._choice_sim, self._cross_model))
-
-       # The name should be `bert`, if dynamic vocabulary is needed
+       # Text Encoder ===>
        bert_config = update_attr("bert_config", bert_config, "num_hidden_layers",
                                  self.task_config, "text_num_hidden_layers")
        self.bert = BertModel(bert_config)
        bert_word_embeddings_weight = self.bert.embeddings.word_embeddings.weight
        bert_position_embeddings_weight = self.bert.embeddings.position_embeddings.weight
+       # <=== End of Text Encoder

+       # Video Encoder ===>
        visual_config = update_attr("visual_config", visual_config, "num_hidden_layers",
                                    self.task_config, "visual_num_hidden_layers")
-       self.visual = VisualBertModel(visual_config)
+       self.visual = VisualModel(visual_config)
        visual_word_embeddings_weight = self.visual.embeddings.word_embeddings.weight
+       # <=== End of Video Encoder

-       if self._choice_sim is False or self.train_sim_after_cross:
-           self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
-           self.cls_visual = VisualBertOnlyMLMHead(visual_config, visual_word_embeddings_weight)
-
+       if self._stage_one is False or self.train_sim_after_cross:
+           # Cross Encoder ===>
            cross_config = update_attr("cross_config", cross_config, "num_hidden_layers",
                                       self.task_config, "cross_num_hidden_layers")
-           self.cross = CrossBertModel(cross_config)
+           self.cross = CrossModel(cross_config)
+           # <=== End of Cross Encoder

            if self.train_sim_after_cross is False:
-               if self.task_config.do_pretrain and self.pretrain_without_decoder:
-                   show_log(task_config, "Pretrain without decoder.")
-               else:
-                   decoder_config = update_attr("decoder_config", decoder_config, "num_decoder_layers",
-                                                self.task_config, "decoder_num_hidden_layers")
-                   self.decoder = DecoderBertModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
+               # Decoder ===>
+               decoder_config = update_attr("decoder_config", decoder_config, "num_decoder_layers",
+                                            self.task_config, "decoder_num_hidden_layers")
+               self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
+               # <=== End of Decoder

+           self.cls = BertOnlyMLMHead(bert_config, bert_word_embeddings_weight)
+           self.cls_visual = VisualOnlyMLMHead(visual_config, visual_word_embeddings_weight)

+           self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)
+           self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
+           self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)

        self.normalize_video = NormalizeVideo(task_config)
@@ -350,36 +174,15 @@ class VLBert(PreTrainedModel):
            batch_size=task_config.batch_size // task_config.n_gpu,
            n_pair=task_config.n_pair,
            hard_negative_rate=task_config.hard_negative_rate, )

        if task_config.use_mil:
-           self.loss_fct = mILNCELoss
            show_log(task_config, "Using MILNCE.")
+           self.loss_fct = CrossEn() if self._stage_two else mILNCELoss
+           self._pretrain_sim_loss_fct = mILNCELoss
        else:
-           self.loss_fct = maxMarginRankingLoss
            show_log(task_config, "Using Ranking Loss.")
+           self.loss_fct = CrossEn() if self._stage_two else maxMarginRankingLoss
+           self._pretrain_sim_loss_fct = maxMarginRankingLoss

-       if self._cross_model:
-           self.loss_fct = CrossEn()
-           show_log(task_config, "Use CE Loss.")
-
-       if self._choice_sim is False or self.train_sim_after_cross:
-           self.similarity_dense = nn.Linear(bert_config.hidden_size, 1)
-
-           self.alm_loss_fct = CrossEntropyLoss(ignore_index=-1)
-           self.decoder_loss_fct = CrossEntropyLoss(ignore_index=-1)
-
-       self.pretrain_with_joint_sim = False
-       if hasattr(self.task_config, 'pretrain_with_joint_sim'):
-           if self.task_config.pretrain_with_joint_sim:
-               self.pretrain_with_joint_sim = True
-               show_log(task_config, "Set pretrain with joint embedding.")
-               if task_config.use_mil:
-                   self._pretrain_sim_loss_fct = mILNCELoss
-                   show_log(task_config, "Pretrain with joint embedding using MILNCE.")
-               else:
-                   self._pretrain_sim_loss_fct = maxMarginRankingLoss
-                   show_log(task_config, "Pretrain with joint embedding using Ranking Loss.")

-       self.apply(self.init_bert_weights)
+       self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, video, video_mask=None,
                pairs_masked_text=None, pairs_token_labels=None, masked_video=None, video_labels_index=None,
@@ -395,17 +198,18 @@ class VLBert(PreTrainedModel):
        input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
        decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])

-       sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids,
-           attention_mask, video, video_mask, shaped=True)
+       sequence_output, visual_output = self.get_sequence_visual_output(input_ids, token_type_ids, attention_mask,
+                                                                        video, video_mask, shaped=True)

        if self.training:
            loss = 0.
-           if self._choice_sim:
-               sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask, shaped=True)
+           if self._stage_one:
+               sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask,
+                                                       video_mask, shaped=True)
                sim_loss = self.loss_fct(sim_matrix)
-               loss = loss + sim_loss
+               loss += sim_loss

-           if self._cross_model and pairs_masked_text is not None and pairs_token_labels is not None:
+           if self._stage_two and pairs_masked_text is not None and pairs_token_labels is not None:

                if self.task_config.do_pretrain:
                    pairs_masked_text = pairs_masked_text.view(-1, pairs_masked_text.shape[-1])

@@ -414,25 +218,25 @@ class VLBert(PreTrainedModel):
                    masked_video = self.normalize_video(masked_video)
                    video_labels_index = video_labels_index.view(-1, video_labels_index.shape[-1])

-                   sequence_output_alm, visual_output_alm = self.get_sequence_visual_output(pairs_masked_text, token_type_ids, attention_mask, masked_video, video_mask, shaped=True)
+                   sequence_output_alm, visual_output_alm = self.get_sequence_visual_output(pairs_masked_text, token_type_ids,
+                                                                                            attention_mask, masked_video, video_mask, shaped=True)

                    cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output_alm, visual_output_alm, attention_mask, video_mask)
                    sequence_cross_output, visual_cross_output = torch.split(cross_output, [attention_mask.size(-1), video_mask.size(-1)], dim=1)

-                   alm_loss = self._calculate_mlm_loss(sequence_cross_output, pairs_token_labels)  # For mlm
-                   loss = loss + alm_loss
+                   alm_loss = self._calculate_mlm_loss(sequence_cross_output, pairs_token_labels)
+                   loss += alm_loss

-                   nce_loss = self._calculate_mfm_loss(visual_cross_output, video, video_mask, video_labels_index)  # For mfm
+                   nce_loss = self._calculate_mfm_loss(visual_cross_output, video, video_mask, video_labels_index)
                    loss += nce_loss

-                   if self.pretrain_with_joint_sim:
-                       sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask,
-                                                               shaped=True, _pretrain_joint=True)
-                       sim_loss_joint = self._pretrain_sim_loss_fct(sim_matrix)
-                       loss = loss + sim_loss_joint
+                   sim_matrix = self.get_similarity_logits(sequence_output, visual_output, attention_mask, video_mask,
+                                                           shaped=True, _pretrain_joint=True)
+                   sim_loss_joint = self._pretrain_sim_loss_fct(sim_matrix)
+                   loss += sim_loss_joint

                if (input_caption_ids is not None) and \
-                       ((self.task_config.do_pretrain and self.pretrain_without_decoder is False)
+                       (self.task_config.do_pretrain
                         or (self.task_config.do_pretrain is False and self.task_config.task_type == "caption")):
                    if self.task_config.do_pretrain:
                        decoder_scores, res_tuples = self._get_decoder_score(sequence_output_alm, visual_output_alm,

@@ -442,21 +246,25 @@ class VLBert(PreTrainedModel):
                        decoder_scores, res_tuples = self._get_decoder_score(sequence_output, visual_output,
                                                                             input_ids, attention_mask, video_mask,
                                                                             input_caption_ids, decoder_mask, shaped=True)
+                   else:
+                       raise NotImplementedError

                    output_caption_ids = output_caption_ids.view(-1, output_caption_ids.shape[-1])
                    decoder_loss = self.decoder_loss_fct(decoder_scores.view(-1, self.bert_config.vocab_size), output_caption_ids.view(-1))
-                   loss = loss + decoder_loss
+                   loss += decoder_loss

-               if (self.with_sim_in_decoder and self.task_config.do_pretrain) or \
-                       self.task_config.task_type == "retrieval":
+               if self.task_config.do_pretrain or self.task_config.task_type == "retrieval":
                    if self.task_config.do_pretrain:
-                       sim_matrix_text_visual = self.get_similarity_logits(sequence_output_alm, visual_output_alm, attention_mask,
-                                                                           video_mask, shaped=True)
+                       sim_matrix_text_visual = self.get_similarity_logits(sequence_output_alm, visual_output_alm,
+                                                                           attention_mask, video_mask, shaped=True)
                    elif self.task_config.task_type == "retrieval":
-                       sim_matrix_text_visual = self.get_similarity_logits(sequence_output, visual_output, attention_mask,
-                                                                           video_mask, shaped=True)
+                       sim_matrix_text_visual = self.get_similarity_logits(sequence_output, visual_output,
+                                                                           attention_mask, video_mask, shaped=True)
                    else:
                        raise NotImplementedError

                    sim_loss_text_visual = self.loss_fct(sim_matrix_text_visual)
-                   loss = loss + sim_loss_text_visual
+                   loss += sim_loss_text_visual

            return loss
        else:
@@ -488,7 +296,6 @@ class VLBert(PreTrainedModel):
        nce_loss = nce_loss.mean()
        return nce_loss

-   # For inference
    def get_sequence_visual_output(self, input_ids, token_type_ids, attention_mask, video, video_mask, shaped=False):
        if shaped is False:
            input_ids = input_ids.view(-1, input_ids.shape[-1])

@@ -497,17 +304,26 @@ class VLBert(PreTrainedModel):
            video_mask = video_mask.view(-1, video_mask.shape[-1])
            video = self.normalize_video(video)

-       encoded_layers, _ = self.bert("bi", input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
+       encoded_layers, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
        sequence_output = encoded_layers[-1]

-       visual_layers, _ = self.visual("bi", video, video_mask, output_all_encoded_layers=True)
+       visual_layers, _ = self.visual(video, video_mask, output_all_encoded_layers=True)
        visual_output = visual_layers[-1]

+       # cannot add any cross code here
+       # because sequence and visual features are exclusive

        return sequence_output, visual_output

+   def _get_cross_output(self, sequence_output, visual_output, attention_mask, video_mask):
+       concat_features = torch.cat((sequence_output, visual_output), dim=1)  # concatenate tokens and frames
+       concat_mask = torch.cat((attention_mask, video_mask), dim=1)
+       text_type_ = torch.zeros_like(attention_mask)
+       video_type_ = torch.ones_like(video_mask)
+       concat_type = torch.cat((text_type_, video_type_), dim=1)
+
+       cross_layers, pooled_output = self.cross(concat_features, concat_type, concat_mask, output_all_encoded_layers=True)
+       cross_output = cross_layers[-1]
+
+       return cross_output, pooled_output, concat_mask
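`_get_cross_output` fuses the two streams by concatenating along the sequence axis and marking each segment with a type id (0 = text, 1 = video) for the cross encoder. A shape-level sketch with hypothetical sizes:

```
# Shape sketch for the text/video fusion (hypothetical sizes).
import torch

batch, n_tokens, n_frames, hidden = 2, 48, 48, 768
sequence_output = torch.randn(batch, n_tokens, hidden)
visual_output = torch.randn(batch, n_frames, hidden)
attention_mask = torch.ones(batch, n_tokens, dtype=torch.long)
video_mask = torch.ones(batch, n_frames, dtype=torch.long)

concat_features = torch.cat((sequence_output, visual_output), dim=1)  # [2, 96, 768]
concat_mask = torch.cat((attention_mask, video_mask), dim=1)          # [2, 96]
concat_type = torch.cat((torch.zeros_like(attention_mask),            # 0 = text
                         torch.ones_like(video_mask)), dim=1)         # 1 = video
```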
    def _mean_pooling_for_similarity(self, sequence_output, visual_output, attention_mask, video_mask,):
        attention_mask_un = attention_mask.to(dtype=torch.float).unsqueeze(-1)
        attention_mask_un[:, 0, :] = 0.

@@ -558,13 +374,12 @@ class VLBert(PreTrainedModel):
        retrieve_logits = torch.cat(retrieve_logits_list, dim=0)
        return retrieve_logits

-   # For inference
    def get_similarity_logits(self, sequence_output, visual_output, attention_mask, video_mask, shaped=False, _pretrain_joint=False):
        if shaped is False:
            attention_mask = attention_mask.view(-1, attention_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

-       if self._cross_model and _pretrain_joint is False:
+       if self._stage_two and _pretrain_joint is False:
            retrieve_logits = self._cross_similarity(sequence_output, visual_output, attention_mask, video_mask)
        else:
            if self.train_sim_after_cross:

@@ -579,19 +394,6 @@ class VLBert(PreTrainedModel):

        return retrieve_logits

-   def _get_cross_output(self, sequence_output, visual_output, attention_mask, video_mask):
-       # Generate one pair (x, y) ===>
-       concat_features = torch.cat((sequence_output, visual_output), dim=1)  # concatnate tokens and frames
-       concat_mask = torch.cat((attention_mask, video_mask), dim=1)
-       text_type_ = torch.zeros_like(attention_mask)
-       video_type_ = torch.ones_like(video_mask)
-       concat_type = torch.cat((text_type_, video_type_), dim=1)
-       # <===
-       cross_layers, pooled_output = self.cross("bi", concat_features, concat_type, concat_mask, output_all_encoded_layers=True)
-       cross_output = cross_layers[-1]
-
-       return cross_output, pooled_output, concat_mask

    def _get_decoder_score(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=False):

        if shaped is False:

@@ -604,16 +406,12 @@ class VLBert(PreTrainedModel):

        res_tuples = ()
        cross_output, pooled_output, concat_mask = self._get_cross_output(sequence_output, visual_output, attention_mask, video_mask)
-       decoder_scores = self.decoder(input_caption_ids, input_ids, encoder_outs=cross_output, answer_mask=decoder_mask, encoder_mask=concat_mask)
+       decoder_scores = self.decoder(input_caption_ids, encoder_outs=cross_output, answer_mask=decoder_mask, encoder_mask=concat_mask)

        return decoder_scores, res_tuples

    def decoder_caption(self, sequence_output, visual_output, input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask,
                        shaped=False, get_logits=False):
-       """
-       :param get_logits: Used in Beam Search
-       :return:
-       """
        if shaped is False:
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            attention_mask = attention_mask.view(-1, attention_mask.shape[-1])

@@ -622,9 +420,9 @@ class VLBert(PreTrainedModel):
        input_caption_ids = input_caption_ids.view(-1, input_caption_ids.shape[-1])
        decoder_mask = decoder_mask.view(-1, decoder_mask.shape[-1])

-       assert self.pretrain_without_decoder is False
        decoder_scores, _ = self._get_decoder_score(sequence_output, visual_output,
-           input_ids, attention_mask, video_mask, input_caption_ids, decoder_mask, shaped=True)
+                                                   input_ids, attention_mask, video_mask,
+                                                   input_caption_ids, decoder_mask, shaped=True)

        if get_logits:
            return decoder_scores
@@ -27,14 +27,13 @@ import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path
+from .until_config import PretrainedConfig
+from .until_module import PreTrainedModel, LayerNorm, ACT2FN

logger = logging.getLogger(__name__)

@@ -51,24 +50,14 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

-def gelu(x):
-    """Implementation of the gelu activation function.
-       For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
-       0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-def swish(x):
-    return x * torch.sigmoid(x)
-
-ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
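The exact erf-based gelu removed here (it now comes in via `ACT2FN` from `until_module`) and GPT's tanh approximation mentioned in its docstring agree closely; a quick comparison sketch:

```
# Sketch: exact erf-based gelu vs. the tanh approximation from the docstring.
import math
import torch

def gelu_erf(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3, 3, 7)
print((gelu_erf(x) - gelu_tanh(x)).abs().max())  # tiny: well under 1e-2 on this range
```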
-class BertConfig(object):
+class BertConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `BertModel`.
    """
+   pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
+   config_name = CONFIG_NAME
+   weights_name = WEIGHTS_NAME

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,

@@ -126,52 +115,6 @@ class BertConfig(object):
        raise ValueError("First argument must be either a vocabulary size (int)"
                         "or the path to a pretrained model config file (str)")

-   @classmethod
-   def from_dict(cls, json_object):
-       """Constructs a `BertConfig` from a Python dictionary of parameters."""
-       config = BertConfig(vocab_size_or_config_json_file=-1)
-       for key, value in json_object.items():
-           config.__dict__[key] = value
-       return config
-
-   @classmethod
-   def from_json_file(cls, json_file):
-       """Constructs a `BertConfig` from a json file of parameters."""
-       with open(json_file, "r", encoding='utf-8') as reader:
-           text = reader.read()
-       return cls.from_dict(json.loads(text))
-
-   def __repr__(self):
-       return str(self.to_json_string())
-
-   def to_dict(self):
-       """Serializes this instance to a Python dictionary."""
-       output = copy.deepcopy(self.__dict__)
-       return output
-
-   def to_json_string(self):
-       """Serializes this instance to a JSON string."""
-       return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
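These serialization helpers move into the shared `PretrainedConfig` base (per the new `from .until_config import PretrainedConfig` above); the underlying pattern is a plain dict/JSON round-trip, sketched here with placeholder values:

```
# Sketch of the config round-trip the removed helpers implement.
import json

config_dict = {'hidden_size': 768, 'num_hidden_layers': 12, 'vocab_size': 30522}
as_json = json.dumps(config_dict, indent=2, sort_keys=True)  # to_json_string()
restored = json.loads(as_json)                               # from_dict() input
assert restored == config_dict
```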
-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
-except ImportError:
-    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
-    class BertLayerNorm(nn.Module):
-        def __init__(self, hidden_size, eps=1e-12):
-            """Construct a layernorm module in the TF style (epsilon inside the square root).
-            """
-            super(BertLayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.ones(hidden_size))
-            self.bias = nn.Parameter(torch.zeros(hidden_size))
-            self.variance_epsilon = eps
-
-        def forward(self, x):
-            u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
-            return self.weight * x + self.bias

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """

@@ -183,7 +126,7 @@ class BertEmbeddings(nn.Module):

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
-       self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+       self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):

@@ -258,7 +201,7 @@ class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-       self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+       self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):

@@ -297,7 +240,7 @@ class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
-       self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+       self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):

@@ -359,7 +302,7 @@ class BertPredictionHeadTransform(nn.Module):
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
-       self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
+       self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
@@ -418,212 +361,7 @@ class BertPreTrainingHeads(nn.Module):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

-class PreTrainedBertModel(nn.Module):
-   """ An abstract class to handle weights initialization and
-       a simple interface for dowloading and loading pretrained models.
-   """
-   def __init__(self, config, *inputs, **kwargs):
-       super(PreTrainedBertModel, self).__init__()
-       if not isinstance(config, BertConfig):
-           raise ValueError(
-               "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
-               "To create a model from a Google pretrained model use "
-               "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
-                   self.__class__.__name__, self.__class__.__name__
-               ))
-       self.config = config
-
-   def init_bert_weights(self, module):
-       """ Initialize the weights.
-       """
-       if isinstance(module, (nn.Linear, nn.Embedding)):
-           # Slightly different from the TF version which uses truncated_normal for initialization
-           # cf https://github.com/pytorch/pytorch/pull/5617
-           module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-       elif isinstance(module, BertLayerNorm):
-           if 'beta' in dir(module) and 'gamma' in dir(module):
-               module.beta.data.zero_()
-               module.gamma.data.fill_(1.0)
-           else:
-               module.bias.data.zero_()
-               module.weight.data.fill_(1.0)
-       if isinstance(module, nn.Linear) and module.bias is not None:
-           module.bias.data.zero_()
-
-   def resize_token_embeddings(self, new_num_tokens=None):
-       model = self.module if hasattr(self, 'module') else self
-       base_model = getattr(model, "bert", model)  # get the base model if needed
-
-       old_embeddings = base_model.embeddings.word_embeddings
-       # get_resized_embeddings <-----
-       if new_num_tokens is None:
-           return old_embeddings
-
-       old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
-       if old_num_tokens == new_num_tokens:
-           return old_embeddings
-
-       # Build new embeddings
-       new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-       device = old_embeddings.weight.device
-       new_embeddings.to(device)
-
-       # initialize all new embeddings (in particular added tokens)
-       self.init_bert_weights(new_embeddings)
-
-       # Copy word embeddings from the previous weights
-       num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
-       new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
-       # ----->
-
-       # Update base model and current model config
-       base_model.embeddings.word_embeddings = new_embeddings
-       base_model.config.vocab_size = new_num_tokens
-       model.config.vocab_size = new_num_tokens
-       model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)
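The removed `resize_token_embeddings` grows the vocabulary by copying trained rows into a larger table and leaving the new rows freshly initialized. The core move, sketched with hypothetical sizes:

```
# Core of the removed resize: copy old rows into a larger embedding table.
import torch.nn as nn

old = nn.Embedding(30522, 768)            # original vocab
new = nn.Embedding(30526, 768)            # vocab grown by 4 added tokens
n_copy = min(old.num_embeddings, new.num_embeddings)
new.weight.data[:n_copy, :] = old.weight.data[:n_copy, :]  # reuse trained rows
# rows 30522..30525 keep their fresh initialization for the new tokens
```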
    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        '''
        abstract from `from_pretrained`
        '''
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights file doesn't exist: {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict
    @classmethod
    def init_preweight(cls, model, state_dict):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
        if len(error_msgs) > 0:
            logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model
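For reference, a small illustrative sketch of the `gamma`/`beta` renaming that `init_preweight` performs before loading (the key names below are typical examples, not taken from a specific checkpoint):

```
# TF-style checkpoints call the LayerNorm parameters gamma/beta, while torch
# modules register them as weight/bias, so the keys must be renamed first.
state_dict = {"embeddings.LayerNorm.gamma": 1, "embeddings.LayerNorm.beta": 2}
renamed = {k.replace("gamma", "weight").replace("beta", "bias"): v
           for k, v in state_dict.items()}
print(sorted(renamed))  # ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight']
```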
    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_classes for BertForSequenceClassification)
        """
        config, state_dict = PreTrainedBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedBertModel.init_preweight(model, state_dict)

        return model
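A minimal usage sketch, assuming a `bert-base-uncased` archive is available locally next to the module or resolvable through the archive map:

```
# Hypothetical usage: get_config resolves the name to a local dir or cached
# download, loads bert_config.json plus pytorch_model.bin, and init_preweight
# copies the matching tensors into the freshly constructed module.
model = BertModel.from_pretrained("bert-base-uncased", type_vocab_size=2)
model.eval()
```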
class BertModel(PreTrainedBertModel):
class BertModel(PreTrainedModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer").

    Params:

@@ -673,12 +411,10 @@ class BertModel(PreTrainedBertModel):
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)
        self.apply(self.init_weights)


    def forward(self, type, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
                attention_sentenceA_mask=None,
                attention_sentenceB_mask=None):
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

@@ -697,7 +433,7 @@ class BertModel(PreTrainedBertModel):
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)

@@ -708,4 +444,4 @@ class BertModel(PreTrainedBertModel):
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output
        return encoded_layers, pooled_output
@@ -27,43 +27,27 @@ import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'cross-bert-base': "[TODO]",
    'cross-bert-large': "[TODO]",
}

CONFIG_NAME = 'cross_bert_config.json'
PRETRAINED_MODEL_ARCHIVE_MAP = {}
CONFIG_NAME = 'cross_config.json'
WEIGHTS_NAME = 'cross_pytorch_model.bin'

def gelu(x):
    """Implementation of the gelu activation function.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

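Since the docstring contrasts the erf-based gelu with OpenAI GPT's tanh approximation, here is a quick standalone check (illustrative only) of how closely the two agree:

```
# Compare the exact erf-based gelu with the tanh approximation on a few points.
import math
import torch

x = torch.linspace(-3.0, 3.0, steps=7)
erf_gelu = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
tanh_gelu = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print((erf_gelu - tanh_gelu).abs().max())  # small (on the order of 1e-3), hence "slightly different"
```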
class CrossBertConfig(object):
    """Configuration class to store the configuration of a `CrossBertModel`.

class CrossConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `CrossModel`.
    """
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    config_name = CONFIG_NAME
    weights_name = WEIGHTS_NAME
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,

@@ -76,10 +60,10 @@ class CrossBertConfig(object):
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs CrossBertConfig.
        """Constructs CrossConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossBertModel`.
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in

@@ -96,7 +80,7 @@ class CrossBertConfig(object):
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `CrossBertModel`.
                `CrossModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        """

@@ -121,64 +105,19 @@ class CrossBertConfig(object):
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `CrossBertConfig` from a Python dictionary of parameters."""
        config = CrossBertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `CrossBertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as CrossBertLayerNorm
except ImportError:
    print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
    class CrossBertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(CrossBertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
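An illustrative check that the fallback above matches PyTorch's built-in `nn.LayerNorm` (which also normalizes with the biased variance) when weight and bias are left at their defaults:

```
# Standalone sketch: y = weight * (x - mean) / sqrt(var + eps) + bias over the
# last dimension, identical to torch's implementation with weight=1, bias=0.
import torch
from torch import nn

x = torch.randn(2, 4, 8)
ref = nn.LayerNorm(8, eps=1e-12)(x)         # PyTorch's implementation
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)   # biased variance, as above
out = (x - u) / torch.sqrt(s + 1e-12)
print(torch.allclose(ref, out, atol=1e-5))  # True
```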
class CrossBertEmbeddings(nn.Module):
class CrossEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(CrossBertEmbeddings, self).__init__()
        super(CrossEmbeddings, self).__init__()

        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, concat_embeddings, concat_type=None):

@@ -198,9 +137,9 @@ class CrossBertEmbeddings(nn.Module):
        embeddings = self.dropout(embeddings)
        return embeddings

class CrossBertSelfAttention(nn.Module):
class CrossSelfAttention(nn.Module):
    def __init__(self, config):
        super(CrossBertSelfAttention, self).__init__()
        super(CrossSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "

@@ -232,7 +171,7 @@ class CrossBertSelfAttention(nn.Module):
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in CrossBertModel forward() function)
        # Apply the attention mask (precomputed for all layers in CrossModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.

@@ -249,11 +188,11 @@ class CrossBertSelfAttention(nn.Module):
        return context_layer


class CrossBertSelfOutput(nn.Module):
class CrossSelfOutput(nn.Module):
    def __init__(self, config):
        super(CrossBertSelfOutput, self).__init__()
        super(CrossSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):

@@ -263,11 +202,11 @@ class CrossBertSelfOutput(nn.Module):
        return hidden_states


class CrossBertAttention(nn.Module):
class CrossAttention(nn.Module):
    def __init__(self, config):
        super(CrossBertAttention, self).__init__()
        self.self = CrossBertSelfAttention(config)
        self.output = CrossBertSelfOutput(config)
        super(CrossAttention, self).__init__()
        self.self = CrossSelfAttention(config)
        self.output = CrossSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)

@@ -275,9 +214,9 @@ class CrossBertAttention(nn.Module):
        return attention_output


class CrossBertIntermediate(nn.Module):
class CrossIntermediate(nn.Module):
    def __init__(self, config):
        super(CrossBertIntermediate, self).__init__()
        super(CrossIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

@@ -288,11 +227,11 @@ class CrossBertIntermediate(nn.Module):
        return hidden_states


class CrossBertOutput(nn.Module):
class CrossOutput(nn.Module):
    def __init__(self, config):
        super(CrossBertOutput, self).__init__()
        super(CrossOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):

@@ -302,12 +241,12 @@ class CrossBertOutput(nn.Module):
        return hidden_states


class CrossBertLayer(nn.Module):
class CrossLayer(nn.Module):
    def __init__(self, config):
        super(CrossBertLayer, self).__init__()
        self.attention = CrossBertAttention(config)
        self.intermediate = CrossBertIntermediate(config)
        self.output = CrossBertOutput(config)
        super(CrossLayer, self).__init__()
        self.attention = CrossAttention(config)
        self.intermediate = CrossIntermediate(config)
        self.output = CrossOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)

@@ -316,10 +255,10 @@ class CrossBertLayer(nn.Module):
        return layer_output


class CrossBertEncoder(nn.Module):
class CrossEncoder(nn.Module):
    def __init__(self, config):
        super(CrossBertEncoder, self).__init__()
        layer = CrossBertLayer(config)
        super(CrossEncoder, self).__init__()
        layer = CrossLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):

@@ -333,9 +272,9 @@ class CrossBertEncoder(nn.Module):
        return all_encoder_layers


class CrossBertPooler(nn.Module):
class CrossPooler(nn.Module):
    def __init__(self, config):
        super(CrossBertPooler, self).__init__()
        super(CrossPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()


@@ -348,13 +287,13 @@ class CrossBertPooler(nn.Module):
        return pooled_output


class CrossBertPredictionHeadTransform(nn.Module):
class CrossPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(CrossBertPredictionHeadTransform, self).__init__()
        super(CrossPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = CrossBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)

@@ -363,18 +302,18 @@ class CrossBertPredictionHeadTransform(nn.Module):
        return hidden_states


class CrossBertLMPredictionHead(nn.Module):
    def __init__(self, config, cross_bert_model_embedding_weights):
        super(CrossBertLMPredictionHead, self).__init__()
        self.transform = CrossBertPredictionHeadTransform(config)
class CrossLMPredictionHead(nn.Module):
    def __init__(self, config, cross_model_embedding_weights):
        super(CrossLMPredictionHead, self).__init__()
        self.transform = CrossPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(cross_bert_model_embedding_weights.size(1),
                                 cross_bert_model_embedding_weights.size(0),
        self.decoder = nn.Linear(cross_model_embedding_weights.size(1),
                                 cross_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = cross_bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(cross_bert_model_embedding_weights.size(0)))
        self.decoder.weight = cross_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(cross_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)

@@ -382,19 +321,19 @@ class CrossBertLMPredictionHead(nn.Module):
        return hidden_states


class CrossBertOnlyMLMHead(nn.Module):
    def __init__(self, config, cross_bert_model_embedding_weights):
        super(CrossBertOnlyMLMHead, self).__init__()
        self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
class CrossOnlyMLMHead(nn.Module):
    def __init__(self, config, cross_model_embedding_weights):
        super(CrossOnlyMLMHead, self).__init__()
        self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class CrossBertOnlyNSPHead(nn.Module):
class CrossOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super(CrossBertOnlyNSPHead, self).__init__()
        super(CrossOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):

@@ -402,10 +341,10 @@ class CrossBertOnlyNSPHead(nn.Module):
        return seq_relationship_score


class CrossBertPreTrainingHeads(nn.Module):
    def __init__(self, config, cross_bert_model_embedding_weights):
        super(CrossBertPreTrainingHeads, self).__init__()
        self.predictions = CrossBertLMPredictionHead(config, cross_bert_model_embedding_weights)
class CrossPreTrainingHeads(nn.Module):
    def __init__(self, config, cross_model_embedding_weights):
        super(CrossPreTrainingHeads, self).__init__()
        self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):

@@ -413,185 +352,16 @@ class CrossBertPreTrainingHeads(nn.Module):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
class PreTrainedCrossBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedCrossBertModel, self).__init__()
        if not isinstance(config, CrossBertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `CrossBertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_cross_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, CrossBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        '''
        abstract from `from_pretrained`
        '''
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = CrossBertConfig.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights file doesn't exist: {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def init_preweight(cls, model, state_dict):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'cross_bert') else 'cross_bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
        if len(error_msgs) > 0:
            logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        """
        Instantiate a PreTrainedCrossBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `cross_bert-base`
                    . `cross_bert-large`
                - a path or url to a pretrained model archive containing:
                    . `cross-bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a CrossBertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific CrossBert class
                (ex: num_classes for CrossBertForSequenceClassification)
        """
        config, state_dict = PreTrainedCrossBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedCrossBertModel.init_preweight(model, state_dict)

        return model
class CrossBertModel(PreTrainedCrossBertModel):
class CrossModel(PreTrainedModel):
    def __init__(self, config):
        super(CrossBertModel, self).__init__(config)
        self.embeddings = CrossBertEmbeddings(config)
        self.encoder = CrossBertEncoder(config)
        self.pooler = CrossBertPooler(config)
        self.apply(self.init_cross_bert_weights)
        super(CrossModel, self).__init__(config)
        self.embeddings = CrossEmbeddings(config)
        self.encoder = CrossEncoder(config)
        self.pooler = CrossPooler(config)
        self.apply(self.init_weights)

    def forward(self, type, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True,
                attention_sentenceA_mask=None,
                attention_sentenceB_mask=None):
    def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True):

        if attention_mask is None:
            attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))

@@ -610,7 +380,7 @@ class CrossBertModel(PreTrainedCrossBertModel):
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(concat_input, concat_type)
@@ -0,0 +1,406 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import numpy as np

import torch
from torch import nn
from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {}
CONFIG_NAME = 'decoder_config.json'
WEIGHTS_NAME = 'decoder_pytorch_model.bin'


class DecoderConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `DecoderModel`.
    """
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    config_name = CONFIG_NAME
    weights_name = WEIGHTS_NAME
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 max_target_embeddings=128,
                 num_decoder_layers=1):
        """Constructs DecoderConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `DecoderModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `DecoderModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            max_target_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            num_decoder_layers: Number of decoder layers stacked in the Transformer decoder.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.max_target_embeddings = max_target_embeddings
            self.num_decoder_layers = num_decoder_layers
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
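A construction sketch for `DecoderConfig` (the values are illustrative): it accepts either an int vocabulary size plus keyword overrides, or a path to a JSON config file.

```
# Either build from explicit values...
config = DecoderConfig(vocab_size_or_config_json_file=30522,
                       num_decoder_layers=3,
                       max_target_embeddings=48)
# ...or load every field from a JSON file:
# config = DecoderConfig("decoder_config.json")
```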
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config, decoder_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(decoder_model_embedding_weights.size(1),
                                 decoder_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = decoder_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(decoder_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config, decoder_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, decoder_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, attention_mask):
        mixed_query_layer = self.query(q)
        mixed_key_layer = self.key(k)
        mixed_value_layer = self.value(v)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_scores
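A standalone shape sketch of the scaled dot-product step above, using the common 768-dim/12-head split (sizes are illustrative):

```
# transpose_for_scores reshapes (B, L, 768) -> (B, 12, L, 64); the matmul then
# scores every query position against every key position per head.
import torch

B, L, H, d = 2, 5, 12, 64
q = torch.randn(B, H, L, d)
k = torch.randn(B, H, L, d)
v = torch.randn(B, H, L, d)
scores = q @ k.transpose(-1, -2) / d ** 0.5                 # (B, H, L, L)
ctx = scores.softmax(dim=-1) @ v                            # (B, H, L, d)
print(ctx.permute(0, 2, 1, 3).reshape(B, L, H * d).shape)   # torch.Size([2, 5, 768])
```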
class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(ACT2FN["gelu"](self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output
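A short illustrative check of the design choice above: a `Conv1d` with kernel size 1 applied over the transposed sequence is the same per-position linear map as an `nn.Linear`, so `w_1`/`w_2` behave like the usual per-token projections of a Transformer FFN.

```
import torch
from torch import nn

x = torch.randn(2, 5, 8)                         # (batch, seq, d_in)
conv = nn.Conv1d(8, 16, 1)
lin = nn.Linear(8, 16)
lin.weight.data = conv.weight.data.squeeze(-1)   # share the same parameters
lin.bias.data = conv.bias.data
out_conv = conv(x.transpose(1, 2)).transpose(1, 2)
print(torch.allclose(out_conv, lin(x), atol=1e-6))  # True
```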
class DecoderAttention(nn.Module):
    def __init__(self, config):
        super(DecoderAttention, self).__init__()
        self.att = MultiHeadAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, q, k, v, attention_mask):
        att_output, attention_probs = self.att(q, k, v, attention_mask)
        attention_output = self.output(att_output, q)
        return attention_output, attention_probs

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super(DecoderLayer, self).__init__()
        self.slf_attn = DecoderAttention(config)
        self.enc_attn = DecoderAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
        dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
        intermediate_output = self.intermediate(dec_output)
        dec_output = self.output(intermediate_output, dec_output)
        return dec_output, dec_att_scores

class DecoderEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
        super(DecoderEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
        self.word_embeddings.weight = decoder_word_embeddings_weight
        self.position_embeddings.weight = decoder_position_embeddings_weight

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        layer = DecoderLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])

    def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
        dec_att_scores = None
        all_encoder_layers = []
        all_dec_att_probs = []
        for layer_module in self.layer:
            hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
                all_dec_att_probs.append(dec_att_scores)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
            all_dec_att_probs.append(dec_att_scores)
        return all_encoder_layers, all_dec_att_probs

class DecoderClassifier(nn.Module):
    def __init__(self, config, embedding_weights):
        super(DecoderClassifier, self).__init__()
        self.cls = BertOnlyMLMHead(config, embedding_weights)

    def forward(self, hidden_states):
        cls_scores = self.cls(hidden_states)
        return cls_scores

class DecoderModel(PreTrainedModel):

    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        final_norm (bool, optional): apply layer norm to the output of the
            final decoder layer (default: True).
    """

    def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
        super(DecoderModel, self).__init__(config)
        self.config = config
        self.max_target_length = config.max_target_embeddings
        self.embeddings = DecoderEmbeddings(config, decoder_word_embeddings_weight, decoder_position_embeddings_weight)
        self.decoder = Decoder(config)
        self.classifier = DecoderClassifier(config, decoder_word_embeddings_weight)
        self.apply(self.init_weights)

    def forward(self, input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
        """
        Args:
            input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention

        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len, vocab)`
                - the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)`
        """
        embedding_output = self.embeddings(input_ids)

        extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2)  # b x 1 x 1 x ls
        extended_encoder_mask = extended_encoder_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0

        extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
        extended_answer_mask = extended_answer_mask.to(dtype=self.dtype)  # fp16 compatibility

        sz_b, len_s, _ = embedding_output.size()
        subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
        self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1)  # b x 1 x ls x ls
        slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=self.dtype)
        self_attn_mask = slf_attn_mask * -10000.0

        decoded_layers, dec_att_scores = self.decoder(embedding_output,
                                                      encoder_outs,
                                                      self_attn_mask,
                                                      extended_encoder_mask,
                                                      )
        sequence_output = decoded_layers[-1]
        cls_scores = self.classifier(sequence_output)

        return cls_scores
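A worked example (illustrative) of the causal mask that `forward` builds with `torch.triu`: future positions are marked with 1, so after the `(1 - answer_mask) + subsequent_mask` step anything greater than 0 becomes `-10000.0` and is effectively removed from the softmax.

```
import torch

len_s = 4
subsequent_mask = torch.triu(torch.ones(len_s, len_s), diagonal=1)
print(subsequent_mask)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])
# Row t can attend to positions <= t only, keeping teacher forcing causal.
```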
@ -27,43 +27,27 @@ import logging
|
|||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .until_config import PretrainedConfig
|
||||
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'visual-bert-base': "[TODO]",
|
||||
'visual-bert-large': "[TODO]",
|
||||
}
|
||||
|
||||
CONFIG_NAME = 'visual_bert_config.json'
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {}
|
||||
CONFIG_NAME = 'visual_config.json'
|
||||
WEIGHTS_NAME = 'visual_pytorch_model.bin'
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
|
||||
|
||||
class VisualBertConfig(object):
|
||||
"""Configuration class to store the configuration of a `VisualBertModel`.
|
||||
|
||||
class VisualConfig(PretrainedConfig):
|
||||
"""Configuration class to store the configuration of a `VisualModel`.
|
||||
"""
|
||||
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
config_name = CONFIG_NAME
|
||||
weights_name = WEIGHTS_NAME
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=4096,
|
||||
hidden_size=768,
|
||||
|
@ -75,7 +59,7 @@ class VisualBertConfig(object):
|
|||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02):
|
||||
"""Constructs VisualBertConfig.
|
||||
"""Constructs VisualConfig.
|
||||
|
||||
Args:
|
||||
vocab_size_or_config_json_file: Size of the encoder layers and the pooler layer.
|
||||
|
@ -117,67 +101,18 @@ class VisualBertConfig(object):
|
|||
raise ValueError("First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `VisualBertConfig` from a Python dictionary of parameters."""
|
||||
config = VisualBertConfig(vocab_size_or_config_json_file=-1)
|
||||
for key, value in json_object.items():
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `VisualBertConfig` from a json file of parameters."""
|
||||
with open(json_file, "r", encoding='utf-8') as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.to_json_string())
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as VisualBertLayerNorm
|
||||
except ImportError:
|
||||
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
|
||||
class VisualBertLayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(VisualBertLayerNorm, self).__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
class VisualBertEmbeddings(nn.Module):
class VisualEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(VisualBertEmbeddings, self).__init__()
        super(VisualEmbeddings, self).__init__()

        self.word_embeddings = nn.Linear(config.vocab_size, config.hidden_size)
        # self.transform_act_fn = ACT2FN[config.hidden_act] \
        #     if isinstance(config.hidden_act, str) else config.hidden_act

        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_embeddings):

@@ -195,10 +130,9 @@ class VisualBertEmbeddings(nn.Module):
        embeddings = self.dropout(embeddings)
        return embeddings


class VisualBertSelfAttention(nn.Module):
class VisualSelfAttention(nn.Module):
    def __init__(self, config):
        super(VisualBertSelfAttention, self).__init__()
        super(VisualSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "

@@ -230,7 +164,7 @@ class VisualBertSelfAttention(nn.Module):
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in VisualBertModel forward() function)
        # Apply the attention mask (precomputed for all layers in VisualModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.

@@ -247,11 +181,11 @@ class VisualBertSelfAttention(nn.Module):
        return context_layer


class VisualBertSelfOutput(nn.Module):
class VisualSelfOutput(nn.Module):
    def __init__(self, config):
        super(VisualBertSelfOutput, self).__init__()
        super(VisualSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):

@@ -261,11 +195,11 @@ class VisualBertSelfOutput(nn.Module):
        return hidden_states


class VisualBertAttention(nn.Module):
class VisualAttention(nn.Module):
    def __init__(self, config):
        super(VisualBertAttention, self).__init__()
        self.self = VisualBertSelfAttention(config)
        self.output = VisualBertSelfOutput(config)
        super(VisualAttention, self).__init__()
        self.self = VisualSelfAttention(config)
        self.output = VisualSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)

@@ -273,9 +207,9 @@ class VisualBertAttention(nn.Module):
        return attention_output


class VisualBertIntermediate(nn.Module):
class VisualIntermediate(nn.Module):
    def __init__(self, config):
        super(VisualBertIntermediate, self).__init__()
        super(VisualIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act

@@ -286,11 +220,11 @@ class VisualBertIntermediate(nn.Module):
        return hidden_states


class VisualBertOutput(nn.Module):
class VisualOutput(nn.Module):
    def __init__(self, config):
        super(VisualBertOutput, self).__init__()
        super(VisualOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):

@@ -300,12 +234,12 @@ class VisualBertOutput(nn.Module):
        return hidden_states


class VisualBertLayer(nn.Module):
class VisualLayer(nn.Module):
    def __init__(self, config):
        super(VisualBertLayer, self).__init__()
        self.attention = VisualBertAttention(config)
        self.intermediate = VisualBertIntermediate(config)
        self.output = VisualBertOutput(config)
        super(VisualLayer, self).__init__()
        self.attention = VisualAttention(config)
        self.intermediate = VisualIntermediate(config)
        self.output = VisualOutput(config)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)

@@ -314,10 +248,10 @@ class VisualBertLayer(nn.Module):
        return layer_output


class VisualBertEncoder(nn.Module):
class VisualEncoder(nn.Module):
    def __init__(self, config):
        super(VisualBertEncoder, self).__init__()
        layer = VisualBertLayer(config)
        super(VisualEncoder, self).__init__()
        layer = VisualLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):

@@ -331,9 +265,9 @@ class VisualBertEncoder(nn.Module):
        return all_encoder_layers


class VisualBertPooler(nn.Module):
class VisualPooler(nn.Module):
    def __init__(self, config):
        super(VisualBertPooler, self).__init__()
        super(VisualPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

@@ -346,13 +280,13 @@ class VisualBertPooler(nn.Module):
        return pooled_output


class VisualBertPredictionHeadTransform(nn.Module):
class VisualPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(VisualBertPredictionHeadTransform, self).__init__()
        super(VisualPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = VisualBertLayerNorm(config.hidden_size, eps=1e-12)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)

@@ -361,18 +295,15 @@ class VisualBertPredictionHeadTransform(nn.Module):
        return hidden_states


class VisualBertLMPredictionHead(nn.Module):
    """
    Note: It is a trap for `visual_bert_model_embedding_weights`, which is different from Embedding of BERT
    """
    def __init__(self, config, visual_bert_model_embedding_weights):
        super(VisualBertLMPredictionHead, self).__init__()
        self.transform = VisualBertPredictionHeadTransform(config)
class VisualLMPredictionHead(nn.Module):
    def __init__(self, config, visual_model_embedding_weights):
        super(VisualLMPredictionHead, self).__init__()
        self.transform = VisualPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.weight = visual_bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(visual_bert_model_embedding_weights.size(1)))
        self.weight = visual_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(visual_model_embedding_weights.size(1)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)

@@ -380,19 +311,19 @@ class VisualBertLMPredictionHead(nn.Module):
        return hidden_states


class VisualBertOnlyMLMHead(nn.Module):
    def __init__(self, config, visual_bert_model_embedding_weights):
        super(VisualBertOnlyMLMHead, self).__init__()
        self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)
class VisualOnlyMLMHead(nn.Module):
    def __init__(self, config, visual_model_embedding_weights):
        super(VisualOnlyMLMHead, self).__init__()
        self.predictions = VisualLMPredictionHead(config, visual_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class VisualBertOnlyNSPHead(nn.Module):
class VisualOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super(VisualBertOnlyNSPHead, self).__init__()
        super(VisualOnlyNSPHead, self).__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):

@@ -400,10 +331,10 @@ class VisualBertOnlyNSPHead(nn.Module):
        return seq_relationship_score


class VisualBertPreTrainingHeads(nn.Module):
    def __init__(self, config, visual_bert_model_embedding_weights):
        super(VisualBertPreTrainingHeads, self).__init__()
        self.predictions = VisualBertLMPredictionHead(config, visual_bert_model_embedding_weights)
class VisualPreTrainingHeads(nn.Module):
    def __init__(self, config, visual_model_embedding_weights):
        super(VisualPreTrainingHeads, self).__init__()
        self.predictions = VisualLMPredictionHead(config, visual_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):

@@ -412,178 +343,11 @@ class VisualBertPreTrainingHeads(nn.Module):
        return prediction_scores, seq_relationship_score


class PreTrainedVisualBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedVisualBertModel, self).__init__()
        if not isinstance(config, VisualBertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `VisualBertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_visual_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, VisualBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, state_dict, task_config=None):
        '''
        abstract from `from_pretrained`
        '''
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
                archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = VisualBertConfig.from_json_file(config_file)
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path)
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weight doesn't exist. {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def init_preweight(cls, model, state_dict):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'visual_bert') else 'visual_bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
        if len(error_msgs) > 0:
            logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedVisualBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `visual-bert-base`
                    . `visual-bert-large`
                - a path or url to a pretrained model archive containing:
                    . `visual_bert_config.json` a configuration file for the model
                    . `visual_pytorch_model.bin` a PyTorch dump of a VisualBertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific VisualBert class
                (ex: num_classes for VisualBertForSequenceClassification)
        """
        config, state_dict = PreTrainedVisualBertModel.get_config(pretrained_model_name, cache_dir, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedVisualBertModel.init_preweight(model, state_dict)

        return model

class VisualBertModel(PreTrainedVisualBertModel):
    """VisualBert model ("Bidirectional Embedding Representations from a Transformer").
class VisualModel(PreTrainedModel):
    """Visual model ("Bidirectional Embedding Representations from a Transformer").

    Params:
        config: a VisualBertConfig class instance with the configuration to build a new model
        config: a VisualConfig class instance with the configuration to build a new model

    Inputs:
        `type`: a str, indicates which masking will be used in the attention, choice from [`bi`, `seq`, `gen`]

@@ -592,7 +356,7 @@ class VisualBertModel(PreTrainedVisualBertModel):
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see Bert paper for more details).
            a `sentence B` token (see paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when

@@ -602,13 +366,13 @@ class VisualBertModel(PreTrainedVisualBertModel):
    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for VisualBert-base, 24 for VisualBert-large), each
                of each attention block (i.e. 12 full sequences for Visual-base, 24 for Visual-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLF`) to train on the Next-Sentence task (see Bert's paper).
            input (`CLF`) to train on the Next-Sentence task (see the paper).

    Example usage:
    ```python

@@ -616,22 +380,21 @@ class VisualBertModel(PreTrainedVisualBertModel):
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

    config = modeling.VisualBertConfig(vocab_size_or_config_json_file=4096, hidden_size=768,
    config = modeling.VisualConfig(vocab_size_or_config_json_file=4096, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.VisualBertModel(config=config)
    model = modeling.VisualModel(config=config)
    all_encoder_layers, pooled_output = model(video, video_mask)
    ```
    """
    def __init__(self, config):
        super(VisualBertModel, self).__init__(config)
        self.embeddings = VisualBertEmbeddings(config)
        self.encoder = VisualBertEncoder(config)
        self.pooler = VisualBertPooler(config)
        self.apply(self.init_visual_bert_weights)
        super(VisualModel, self).__init__(config)
        self.embeddings = VisualEmbeddings(config)
        self.encoder = VisualEncoder(config)
        self.pooler = VisualPooler(config)
        self.apply(self.init_weights)

    def forward(self, type, video, attention_mask=None, output_all_encoded_layers=True,
                attention_sentenceA_mask=None, attention_sentenceB_mask=None):
    def forward(self, video, attention_mask=None, output_all_encoded_layers=True):

        if attention_mask is None:
            attention_mask = torch.ones(video.size(0), video.size(1))

@@ -648,7 +411,7 @@ class VisualBertModel(PreTrainedVisualBertModel):
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(video)
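The comment block above describes the standard additive-mask trick; this small illustration (values are illustrative, not from the repository) shows how a padding mask becomes large negative logits:

```python
# 1.0 (attend) maps to 0.0; 0.0 (padding) maps to -10000.0, which is
# effectively -inf once it is added to the raw attention scores.
import torch

attention_mask = torch.tensor([[1.0, 1.0, 0.0]])     # last position is padding
extended = attention_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
extended = (1.0 - extended) * -10000.0
print(extended)  # tensor([[[[    -0.,     -0., -10000.]]]])
```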
@@ -110,8 +110,6 @@ class BertAdam(Optimizer):
        if closure is not None:
            loss = closure()

        warned_for_t_total = False

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:

@@ -139,8 +137,10 @@ class BertAdam(Optimizer):

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(1 - beta1, grad)
                next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                # next_m.mul_(beta1).add_(1 - beta1, grad) --> pytorch 1.7
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                # next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) --> pytorch 1.7
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*

@@ -157,13 +157,6 @@ class BertAdam(Optimizer):
                    schedule_fct = SCHEDULES[group['schedule']]
                    progress = state['step'] / group['t_total']
                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
                    # warning for exceeding t_total (only active with warmup_linear)
                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
                        logger.warning(
                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
                        warned_for_t_total = True
                    # end warning
                else:
                    lr_scheduled = group['lr']

@@ -172,9 +165,4 @@ class BertAdam(Optimizer):

                state['step'] += 1

                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
                # No bias correction
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']

        return loss
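The keyword-argument forms introduced above are the newer PyTorch spellings of the same in-place Adam moment updates; a minimal sketch (tensors are illustrative):

```python
# Both spellings compute m = beta1 * m + (1 - beta1) * grad and
# v = beta2 * v + (1 - beta2) * grad * grad; the positional-scalar form
# is deprecated in recent PyTorch, hence the change above.
import torch

m, grad, beta1 = torch.zeros(3), torch.ones(3), 0.9
m.mul_(beta1).add_(grad, alpha=1 - beta1)            # current form
v, beta2 = torch.zeros(3), 0.999
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # current form
print(m, v)
```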
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import logging
import tarfile
import tempfile
import shutil
import torch
from .file_utils import cached_path

logger = logging.getLogger(__name__)

class PretrainedConfig(object):

    pretrained_model_archive_map = {}
    config_name = ""
    weights_name = ""

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, cls.config_name)
        config = cls.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, cls.weights_name)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path, map_location='cpu')
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weight doesn't exist. {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
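A hedged sketch of how a concrete config plugs into this abstract class; `MyConfig` and its fields are hypothetical, not part of the repository:

```python
# Hypothetical subclass: from_dict/from_json_file round-trip through
# to_json_string because the config state lives entirely in __dict__.
class MyConfig(PretrainedConfig):
    config_name = "my_config.json"
    weights_name = "my_pytorch_model.bin"

    def __init__(self, vocab_size_or_config_json_file=-1, hidden_size=768):
        self.vocab_size = vocab_size_or_config_json_file
        self.hidden_size = hidden_size

cfg = MyConfig.from_dict({"vocab_size": 30522, "hidden_size": 768})
print(cfg.to_json_string())
```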
@@ -0,0 +1,251 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import math
from modules.until_config import PretrainedConfig

logger = logging.getLogger(__name__)

def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
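A quick numeric comparison of the exact erf-based GELU above with the tanh approximation quoted in the docstring (illustrative; the gap is small):

```python
import math
import torch

x = torch.linspace(-3, 3, 61)
exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print(torch.max(torch.abs(exact - approx)).item())  # small, on the order of 1e-3 or below
```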

def swish(x):
    return x * torch.sigmoid(x)

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

class PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                old_keys.append(key)
                new_keys.append(prefix + key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')

        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(missing_keys)))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n   " + "\n   ".join(unexpected_keys)))
            if len(error_msgs) > 0:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n   " + "\n   ".join(error_msgs)))

        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = cls.init_preweight(model, state_dict)

        return model
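A hedged sketch of the intended call pattern (the class names follow the docstrings elsewhere in this commit; "visual-base" is an illustrative archive name, and `VisualConfig` is assumed to subclass `PretrainedConfig`):

```python
# Resolve config + weights through the config class, then build the model.
config, state_dict = VisualConfig.get_config(
    "visual-base", cache_dir=None, type_vocab_size=2, state_dict=None)
model = VisualModel.from_pretrained(config, state_dict=state_dict)
```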

##################################
###### LOSS FUNCTION #############
##################################
class CrossEn(nn.Module):
    def __init__(self,):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss
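Usage sketch for `CrossEn` (values illustrative): the diagonal of a video-text similarity matrix holds the matched pairs, so the loss is the mean negative log-softmax mass on the diagonal:

```python
import torch

sim = torch.tensor([[5.0, 1.0],
                    [0.5, 4.0]])  # [batch, batch] similarity scores
loss = CrossEn()(sim)
print(loss.item())  # small here, since the diagonal dominates each row
```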

class MILNCELoss(nn.Module):
    def __init__(self, batch_size=1, n_pair=1,):
        super(MILNCELoss, self).__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair
        torch_v = float(".".join(torch.__version__.split(".")[:2]))
        self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8

    def forward(self, sim_matrix):
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12

        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair // 2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
        return sim_loss
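Usage sketch for `MILNCELoss` (illustrative): with `batch_size=2` and `n_pair=2` the similarity matrix is 4x4, and the positives are the 2x2 blocks on the block diagonal built by the Kronecker mask above:

```python
import torch

loss_fn = MILNCELoss(batch_size=2, n_pair=2)
sim = torch.randn(4, 4)  # [batch * n_pair, batch * n_pair]
print(loss_fn(sim).item())
```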

class MaxMarginRankingLoss(nn.Module):
    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        if n_pair > 1 and batch_size > 1:
            alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
            mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
            self.mm_mask = mm_mask.float()

    def forward(self, x):
        d = torch.diag(x)
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
                     F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()
|
@ -1,713 +0,0 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'decoder-bert-base': "[TODO]",
|
||||
'decoder-bert-large': "[TODO]",
|
||||
}
|
||||
|
||||
CONFIG_NAME = 'decoder_bert_config.json'
|
||||
WEIGHTS_NAME = 'decoder_pytorch_model.bin'
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
||||
|
||||
|
||||
class BertConfig(object):
|
||||
"""Configuration class to store the configuration of a `BertModel`.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_vocab_size=2,
|
||||
initializer_range=0.02,
|
||||
max_target_embeddings=128,
|
||||
num_decoder_layers=1):
|
||||
"""Constructs BertConfig.
|
||||
|
||||
Args:
|
||||
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
|
||||
hidden_size: Size of the encoder layers and the pooler layer.
|
||||
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads: Number of attention heads for each attention layer in
|
||||
the Transformer encoder.
|
||||
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
||||
layer in the Transformer encoder.
|
||||
hidden_act: The non-linear activation function (function or string) in the
|
||||
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
||||
hidden_dropout_prob: The dropout probabilitiy for all fully connected
|
||||
layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob: The dropout ratio for the attention
|
||||
probabilities.
|
||||
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
||||
`BertModel`.
|
||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
max_target_embeddings: The maximum sequence length that this model might
|
||||
ever be used with. Typically set this to something large just in case
|
||||
(e.g., 512 or 1024 or 2048).
|
||||
num_decoder_layers:
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif isinstance(vocab_size_or_config_json_file, int):
|
||||
self.vocab_size = vocab_size_or_config_json_file
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.max_target_embeddings = max_target_embeddings
|
||||
self.num_decoder_layers = num_decoder_layers
|
||||
else:
|
||||
raise ValueError("First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_object):
|
||||
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
||||
config = BertConfig(vocab_size_or_config_json_file=-1)
|
||||
for key, value in json_object.items():
|
||||
config.__dict__[key] = value
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `BertConfig` from a json file of parameters."""
|
||||
with open(json_file, "r", encoding='utf-8') as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.to_json_string())
|
||||
|
||||
def to_dict(self):
|
||||
"""Serializes this instance to a Python dictionary."""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
return output
|
||||
|
||||
def to_json_string(self):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedLayerNorm as DecoderBertLayerNorm
|
||||
except ImportError:
|
||||
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
|
||||
class DecoderBertLayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(DecoderBertLayerNorm, self).__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
|
||||
class BertSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertSelfOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
class BertIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertIntermediate, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertOutput, self).__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states, input_tensor):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPredictionHeadTransform, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLMPredictionHead(nn.Module):
|
||||
def __init__(self, config, bert_model_embedding_weights):
|
||||
super(BertLMPredictionHead, self).__init__()
|
||||
self.transform = BertPredictionHeadTransform(config)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
|
||||
bert_model_embedding_weights.size(0),
|
||||
bias=False)
|
||||
self.decoder.weight = bert_model_embedding_weights
|
||||
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states) + self.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertOnlyMLMHead(nn.Module):
|
||||
def __init__(self, config, bert_model_embedding_weights):
|
||||
super(BertOnlyMLMHead, self).__init__()
|
||||
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
|
||||
|
||||
def forward(self, sequence_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class PreTrainedDecoderBertModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(PreTrainedDecoderBertModel, self).__init__()
|
||||
if not isinstance(config, BertConfig):
|
||||
raise ValueError(
|
||||
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
|
||||
"To create a model from a Google pretrained model use "
|
||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||
self.__class__.__name__, self.__class__.__name__
|
||||
))
|
||||
self.config = config
|
||||
|
||||
def init_bert_weights(self, module):
|
||||
""" Initialize the weights.
|
||||
"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
elif isinstance(module, DecoderBertLayerNorm):
|
||||
if 'beta' in dir(module) and 'gamma' in dir(module):
|
||||
module.beta.data.zero_()
|
||||
module.gamma.data.fill_(1.0)
|
||||
else:
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens=None):
|
||||
model = self.module if hasattr(self, 'module') else self
|
||||
base_model = getattr(model, "bert", model) # get the base model if needed
|
||||
|
||||
old_embeddings = base_model.embeddings.word_embeddings
|
||||
# get_resized_embeddings <-----
|
||||
if new_num_tokens is None:
|
||||
return old_embeddings
|
||||
|
||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||
if old_num_tokens == new_num_tokens:
|
||||
return old_embeddings
|
||||
|
||||
# Build new embeddings
|
||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
||||
device = old_embeddings.weight.device
|
||||
new_embeddings.to(device)
|
||||
|
||||
# initialize all new embeddings (in particular added tokens)
|
||||
self.init_bert_weights(new_embeddings)
|
||||
|
||||
# Copy word embeddings from the previous weights
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
|
||||
# ----->
|
||||
|
||||
# Update base model and current model config
|
||||
base_model.embeddings.word_embeddings = new_embeddings
|
||||
base_model.config.vocab_size = new_num_tokens
|
||||
model.config.vocab_size = new_num_tokens
|
||||
model.cls = BertOnlyMLMHead(model.config, base_model.embeddings.word_embeddings.weight).to(device)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
|
||||
'''
|
||||
abstract from `from_pretrained`
|
||||
'''
|
||||
archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
|
||||
if os.path.exists(archive_file) is False:
|
||||
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
|
||||
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
|
||||
else:
|
||||
archive_file = pretrained_model_name
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
if task_config is None or task_config.local_rank == 0:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list. "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
archive_file))
|
||||
return None
|
||||
if resolved_archive_file == archive_file:
|
||||
if task_config is None or task_config.local_rank == 0:
|
||||
logger.info("loading archive file {}".format(archive_file))
|
||||
else:
|
||||
if task_config is None or task_config.local_rank == 0:
|
||||
logger.info("loading archive file {} from cache at {}".format(
|
||||
archive_file, resolved_archive_file))
|
||||
tempdir = None
|
||||
if os.path.isdir(resolved_archive_file):
|
||||
serialization_dir = resolved_archive_file
|
||||
else:
|
||||
# Extract archive to temp dir
|
||||
tempdir = tempfile.mkdtemp()
|
||||
if task_config is None or task_config.local_rank == 0:
|
||||
logger.info("extracting archive file {} to temp dir {}".format(
|
||||
resolved_archive_file, tempdir))
|
||||
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
||||
archive.extractall(tempdir)
|
||||
serialization_dir = tempdir
|
||||
# Load config
|
||||
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
||||
config = BertConfig.from_json_file(config_file)
|
||||
config.type_vocab_size = type_vocab_size
|
||||
if task_config is None or task_config.local_rank == 0:
|
||||
logger.info("Model config {}".format(config))
|
||||
|
||||
if state_dict is None:
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
if os.path.exists(weights_path):
|
||||
state_dict = torch.load(weights_path)
|
||||
else:
|
||||
if task_config is None or task_config.local_rank == 0:
|
||||
logger.info("Weight doesn't exsits. {}".format(weights_path))
|
||||
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
|
||||
return config, state_dict
|
||||
|
||||
@classmethod
|
||||
def init_preweight(cls, model, state_dict):
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=''):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
|
||||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
|
||||
if len(error_msgs) > 0:
|
||||
logger.error("Weights from pretrained model cause errors in {}: {}".format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
|
||||
|
||||
return model
|
||||
|
||||

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None,
                        cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
        """
        Instantiate a PreTrainedDecoderBertModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_classes for BertForSequenceClassification)
        """
        config, state_dict = PreTrainedDecoderBertModel.get_config(pretrained_model_name, cache_dir, type_vocab_size, state_dict)
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = PreTrainedDecoderBertModel.init_preweight(model, state_dict)

        return model
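
    # Illustrative usage only (assumes a concrete subclass of
    # PreTrainedDecoderBertModel, here called `DecoderModel`, and the local
    # `bert-base-uncased` files prepared in the Preliminary step):
    #
    #   decoder = DecoderModel.from_pretrained('bert-base-uncased',
    #                                          state_dict=None, cache_dir=None,
    #                                          type_vocab_size=2)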

# import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn
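
# Illustrative usage only (assumed shapes; temperature is typically sqrt(d_k)):
#
#   sdp = ScaledDotProductAttention(temperature=64 ** 0.5)
#   q = k = v = torch.randn(2, 10, 64)    # (batch, seq_len, d_k)
#   output, attn = sdp(q, k, v)           # output: (2, 10, 64), attn: (2, 10, 10)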

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, attention_mask):
        mixed_query_layer = self.query(q)
        mixed_key_layer = self.key(k)
        mixed_value_layer = self.value(v)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_scores
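
# Shape walk-through (illustrative, assuming hidden_size=768 and 12 heads):
# q/k/v of shape (batch, seq, 768) are projected and reshaped by
# transpose_for_scores to (batch, 12, seq, 64); attention_mask is additive and
# broadcastable to (batch, 1, seq_q, seq_k), holding 0.0 for visible positions
# and -10000.0 for masked ones (see DecoderBertModel.forward below).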

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(gelu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output
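
# Note (illustrative): a kernel-size-1 Conv1d applied to the transposed input
# (batch, d_in, seq) acts as the same Linear map at every position, so w_1/w_2
# above are position-wise projections; the two transposes in forward() convert
# between (batch, seq, d_in) and the (batch, channels, seq) layout Conv1d expects.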

class DecoderAttention(nn.Module):
    def __init__(self, config):
        super(DecoderAttention, self).__init__()
        self.att = MultiHeadAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, q, k, v, attention_mask):
        att_output, attention_probs = self.att(q, k, v, attention_mask)
        attention_output = self.output(att_output, q)
        return attention_output, attention_probs

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super(DecoderLayer, self).__init__()
        self.slf_attn = DecoderAttention(config)
        self.enc_attn = DecoderAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
        dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
        intermediate_output = self.intermediate(dec_output)
        dec_output = self.output(intermediate_output, dec_output)
        return dec_output, dec_att_scores
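
# Illustrative data flow through one DecoderLayer (assumed shapes):
#   dec_input:         (batch, tgt_len, hidden)  -- masked self-attention over the target
#   enc_output:        (batch, src_len, hidden)  -- keys/values for cross-attention
#   slf_attn_mask:     additive (batch, 1, tgt_len, tgt_len) causal + padding mask
#   dec_enc_attn_mask: additive (batch, 1, 1, src_len) encoder padding mask
# followed by the BertIntermediate/BertOutput feed-forward block.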

class BertDecoderEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.
    """
    def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
        super(BertDecoderEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
        self.word_embeddings.weight = bert_word_embeddings_weight
        self.position_embeddings.weight = bert_position_embeddings_weight

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = DecoderBertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
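
# Note (illustrative): the word and position embedding weights above are tied
# to the pretrained encoder BERT's tables passed into the constructor, so the
# decoder shares its vocabulary and positional embeddings with the encoder.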

class BertDecoder(nn.Module):
    def __init__(self, config):
        super(BertDecoder, self).__init__()
        layer = DecoderLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])

    def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
        dec_att_scores = None
        all_encoder_layers = []
        all_dec_att_probs = []
        for layer_module in self.layer:
            hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
                all_dec_att_probs.append(dec_att_scores)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
            all_dec_att_probs.append(dec_att_scores)
        return all_encoder_layers, all_dec_att_probs

class BertDecoderClassifier(nn.Module):
    def __init__(self, config, embedding_weights):
        super(BertDecoderClassifier, self).__init__()
        self.cls = BertOnlyMLMHead(config, embedding_weights)

    def forward(self, hidden_states):
        cls_scores = self.cls(hidden_states)
        return cls_scores

class DecoderBertModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_decoder_layers* layers. Each
    layer is a :class:`DecoderLayer`.

    Args:
        config: model configuration (hidden sizes, vocabulary size, number of
            decoder layers, maximum target length, etc.)
    """

    def init_decoder_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, DecoderBertLayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def __init__(self, config, bert_word_embeddings_weight, bert_position_embeddings_weight):
        super(DecoderBertModel, self).__init__()
        self.config = config
        self.max_target_length = config.max_target_embeddings
        self.embeddings = BertDecoderEmbeddings(config, bert_word_embeddings_weight, bert_position_embeddings_weight)
        self.decoder = BertDecoder(config)
        self.classifier = BertDecoderClassifier(config, bert_word_embeddings_weight)
        self.apply(self.init_decoder_bert_weights)

    def forward(self, input_ids, encoder_input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
        """
        Args:
            input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention

        Returns:
            the classifier scores over the vocabulary, of shape `(batch, tgt_len, vocab)`
        """
        embedding_output = self.embeddings(input_ids)

        extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2)  # b x 1 x 1 x ls
        extended_encoder_mask = extended_encoder_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0

        extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
        extended_answer_mask = extended_answer_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility

        sz_b, len_s, _ = embedding_output.size()
        subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
        self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1)  # b x 1 x ls x ls
        # Combine the causal mask with the target padding mask, then convert to
        # the additive form expected by MultiHeadAttention (0.0 keep, -10000.0 mask).
        slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=next(self.parameters()).dtype)
        self_attn_mask = slf_attn_mask * -10000.0

        decoded_layers, dec_att_scores = self.decoder(embedding_output,
                                                      encoder_outs,
                                                      self_attn_mask,
                                                      extended_encoder_mask,
                                                      )
        sequence_output = decoded_layers[-1]
        cls_scores = self.classifier(sequence_output)

        return cls_scores
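
# Worked example of the self-attention mask (illustrative): for tgt_len = 3 and
# answer_mask = [1, 1, 1], subsequent_mask is
#   [[0., 1., 1.],
#    [0., 0., 1.],
#    [0., 0., 0.]]
# so position i may attend to position j only when j <= i; self_attn_mask then
# holds 0.0 at allowed positions and -10000.0 elsewhere, which MultiHeadAttention
# adds to the raw attention scores before the softmax.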
5  requirements.txt
@ -0,0 +1,5 @@
torch==1.7.0
tqdm
boto3
requests
pandas
16  util.py
@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import threading
from torch._utils import ExceptionWrapper
import logging

def get_a_var(obj):
    if isinstance(obj, torch.Tensor):

@ -56,4 +57,17 @@ def parallel_apply(fct, model, inputs, device_ids):
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs

def get_logger(filename=None):
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        handler = logging.FileHandler(filename)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logging.getLogger().addHandler(handler)
    return logger
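
A minimal usage sketch of the new helper (illustrative; assumes `util.py` is importable from the repo root):
```
from util import get_logger

logger = get_logger("train.log")  # logs to stdout via basicConfig and to train.log
logger.info("finetune started")
```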