Merge pull request #2115 from suvrat96/add_mmbt_model
[WIP] Add MMBT Model to Transformers Repo
This commit is contained in:
Коммит
73f6e9817c
|
@ -148,7 +148,8 @@ At some point in the future, you'll be able to seamlessly move from pre-training
|
|||
11. **[ALBERT](https://github.com/google-research/ALBERT)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
|
||||
12. **[T5](https://github.com/google-research/text-to-text-transfer-transformer)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
13. **[XLM-RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/xlmr)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
|
||||
14. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.
|
||||
14. **[MMBT](https://github.com/facebookresearch/mmbt/)** (from Facebook), released together with the paper [Supervised Multimodal Bitransformers for Classifying Images and Text](https://arxiv.org/pdf/1909.02950.pdf) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
|
||||
15. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.
|
||||
|
||||
These implementations have been tested on several datasets (see the example scripts) and should match the performance of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on performance in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).
|
||||
|
||||
|
|
|
@ -734,3 +734,28 @@ Training with the previously defined hyper-parameters yields the following resul
|
|||
```bash
|
||||
acc = 0.7093812375249501
|
||||
```
|
||||
|
||||
## MM-IMDb
|
||||
|
||||
Based on the script [`run_mmimdb.py`](https://github.com/huggingface/transformers/blob/master/examples/mm-imdb/run_mmimdb.py).
|
||||
|
||||
[MM-IMDb](http://lisi1.unal.edu.co/mmimdb/) is a multimodal dataset with around 26,000 movies, including images, plots and other metadata.
|
||||
|
||||
### Training on MM-IMDb
|
||||
|
||||
```
|
||||
python run_mmimdb.py \
|
||||
--data_dir /path/to/mmimdb/dataset/ \
|
||||
--model_type bert \
|
||||
--model_name_or_path bert-base-uncased \
|
||||
--output_dir /path/to/save/dir/ \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--max_seq_length 512 \
|
||||
--gradient_accumulation_steps 20 \
|
||||
--num_image_embeds 3 \
|
||||
--num_train_epochs 100 \
|
||||
--patience 5
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,504 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for multimodal multiclass prediction on MM-IMDB dataset."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import json
|
||||
from sklearn.metrics import f1_score
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
# Prefer the TensorBoard writer bundled with PyTorch; fall back to the
# standalone tensorboardX package only when that import is unavailable.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter
|
||||
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms
|
||||
|
||||
from transformers import (WEIGHTS_NAME,
|
||||
BertConfig, BertModel, BertTokenizer,
|
||||
RobertaConfig, RobertaModel, RobertaTokenizer,
|
||||
XLMConfig, XLMModel, XLMTokenizer,
|
||||
XLNetConfig, XLNetModel, XLNetTokenizer,
|
||||
DistilBertConfig, DistilBertModel, DistilBertTokenizer,
|
||||
AlbertConfig, AlbertModel, AlbertTokenizer,
|
||||
MMBTForClassification, MMBTConfig)
|
||||
|
||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shortcut names of every pretrained checkpoint known to the supported config
# classes; only used to build the --model_name_or_path help text.
# AlbertConfig is included so the list agrees with MODEL_CLASSES below.
ALL_MODELS = sum(
    (
        tuple(conf.pretrained_config_archive_map.keys())
        for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig, AlbertConfig)
    ),
    (),
)

# --model_type -> (config class, bare transformer model class, tokenizer class)
MODEL_CLASSES = {
    'bert': (BertConfig, BertModel, BertTokenizer),
    'xlnet': (XLNetConfig, XLNetModel, XLNetTokenizer),
    'xlm': (XLMConfig, XLMModel, XLMTokenizer),
    'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer),
    'distilbert': (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
    'albert': (AlbertConfig, AlbertModel, AlbertTokenizer)
}
|
||||
|
||||
|
||||
def set_seed(args):
    """Seed the Python, NumPy and PyTorch random generators from ``args.seed``.

    The CUDA generators are seeded as well when ``args.n_gpu`` is positive.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, criterion):
    """Train ``model`` on ``train_dataset`` with gradient accumulation and early stopping.

    Args:
        args: Parsed command-line namespace (see ``main``). This function sets
            ``args.train_batch_size`` and, when ``args.max_steps > 0``,
            overwrites ``args.num_train_epochs``.
        train_dataset: Dataset whose items are batched with ``collate_fn``.
        model: Module called with ``input_ids``, ``input_modal``,
            ``attention_mask``, ``modal_start_tokens`` and ``modal_end_tokens``;
            the first element of its output is used as the logits.
        tokenizer: Forwarded unchanged to ``evaluate``.
        criterion: Callable ``criterion(logits, labels)`` returning the loss.

    Returns:
        ``(global_step, average_loss)`` where ``global_step`` is the number of
        optimizer updates performed and ``average_loss`` is the accumulated
        loss divided by it (the divisor is clamped to 1 when no update ran).
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn,
                                  num_workers=args.num_workers)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Parameters whose name contains one of these fragments get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
                                                num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_f1, n_no_improve = 0, 0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        if isinstance(train_sampler, DistributedSampler):
            # Without this the distributed sampler reuses the same shuffle every epoch.
            train_sampler.set_epoch(epoch)

        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            # Batch layout is fixed by collate_fn:
            # (text ids, attention mask, image, start token, end token, labels)
            labels = batch[5]
            inputs = {'input_ids': batch[0],
                      'input_modal': batch[2],
                      'attention_mask': batch[1],
                      'modal_start_tokens': batch[3],
                      'modal_end_tokens': batch[4]}
            outputs = model(**inputs)
            logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            loss = criterion(logits, labels)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer, criterion)
                        for key, value in results.items():
                            eval_key = 'eval_{}'.format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs['learning_rate'] = learning_rate_scalar
                    logs['loss'] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{'step': global_step}}))

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, 'module') else model
                    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as max_steps updates are done (">" would run one
            # update beyond the schedule the scheduler was built for).
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

        # Early stopping on dev micro-F1, checked once per epoch (single-process runs only).
        if args.local_rank == -1:
            results = evaluate(args, model, tokenizer, criterion)
            if results['micro_f1'] > best_f1:
                best_f1 = results['micro_f1']
                n_no_improve = 0
            else:
                n_no_improve += 1

            if n_no_improve > args.patience:
                train_iterator.close()
                break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # Guard against division by zero when no optimizer update was performed.
    return global_step, tr_loss / max(global_step, 1)
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, criterion, prefix=""):
    """Evaluate ``model`` on the dev split and write the metrics to disk.

    The dev set is loaded through ``load_examples(..., evaluate=True)``.
    Predictions are obtained by thresholding ``sigmoid(logits)`` at 0.5.

    Args:
        args: Parsed command-line namespace; ``args.eval_batch_size`` is set here.
        model: Module to evaluate (may already be wrapped in ``DataParallel``).
        tokenizer: Forwarded to ``load_examples``.
        criterion: Callable ``criterion(logits, labels)`` returning the loss.
        prefix: Sub-directory of ``args.output_dir`` in which
            ``eval_results.txt`` is written; also used in log messages.

    Returns:
        Dict with keys ``"loss"``, ``"macro_f1"`` and ``"micro_f1"``.
    """
    eval_output_dir = args.output_dir
    eval_dataset = load_examples(args, tokenizer, evaluate=True)

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size,
                                 collate_fn=collate_fn)

    # multi-gpu eval; train() passes a model that is already wrapped, so avoid
    # nesting DataParallel inside DataParallel.
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    # Collect per-batch arrays and join them once instead of re-copying with
    # np.append on every batch.
    pred_chunks = []
    label_chunks = []
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            # Batch layout is fixed by collate_fn:
            # (text ids, attention mask, image, start token, end token, labels)
            labels = batch[5]
            inputs = {'input_ids': batch[0],
                      'input_modal': batch[2],
                      'attention_mask': batch[1],
                      'modal_start_tokens': batch[3],
                      'modal_end_tokens': batch[4]}
            outputs = model(**inputs)
            logits = outputs[0]  # model outputs are always tuple in transformers (see doc)
            tmp_eval_loss = criterion(logits, labels)
            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        pred_chunks.append(torch.sigmoid(logits).detach().cpu().numpy() > 0.5)
        label_chunks.append(labels.detach().cpu().numpy())

    preds = np.concatenate(pred_chunks, axis=0)
    out_label_ids = np.concatenate(label_chunks, axis=0)

    eval_loss = eval_loss / nb_eval_steps
    result = {
        "loss": eval_loss,
        "macro_f1": f1_score(out_label_ids, preds, average="macro"),
        "micro_f1": f1_score(out_label_ids, preds, average="micro")
    }

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
|
||||
|
||||
|
||||
def load_examples(args, tokenizer, evaluate=False):
    """Build the ``JsonlDataset`` for ``dev.jsonl`` (``evaluate=True``) or ``train.jsonl``.

    The file is looked up inside ``args.data_dir``.
    """
    split_file = "dev.jsonl" if evaluate else "train.jsonl"
    jsonl_path = os.path.join(args.data_dir, split_file)
    image_transforms = get_image_transforms()
    label_names = get_mmimdb_labels()
    # Text length left after the image embeddings and two further positions
    # (presumably the start/end tokens JsonlDataset splits off -- TODO confirm).
    max_text_length = args.max_seq_length - args.num_image_embeds - 2
    return JsonlDataset(jsonl_path, tokenizer, image_transforms, label_names, max_text_length)
|
||||
|
||||
|
||||
def _build_arg_parser():
    """Create the command-line parser for the MM-IMDb fine-tuning script."""
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .jsonl files for MMIMDB.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--num_image_embeds", default=1, type=int,
                        help="Number of Image Embeddings from the Image Encoder")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Run evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--patience", default=5, type=int,
                        help="Patience for Early Stopping.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--num_workers', type=int, default=8,
                        help="number of worker threads for dataloading")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
    return parser


def main():
    """Parse the command line, then train and/or evaluate an MMBT classifier on MM-IMDb.

    Returns:
        Dict of evaluation metrics; each key is suffixed with ``_<global_step>``
        (empty suffix when a single checkpoint is evaluated). Empty when
        ``--do_eval`` is not set or on non-primary distributed ranks.

    Raises:
        ValueError: if training is requested into an existing, non-empty
            output directory without ``--overwrite_output_dir``.
    """
    args = _build_arg_parser().parse_args()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1

    args.device = device

    # Setup logging
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    # Setup model
    labels = get_mmimdb_labels()
    num_labels = len(labels)
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    cache_dir = args.cache_dir if args.cache_dir else None
    # cache_dir is passed here too so the config honours --cache_dir like the
    # tokenizer and model below.
    transformer_config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
                                                      cache_dir=cache_dir)
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
                                                do_lower_case=args.do_lower_case,
                                                cache_dir=cache_dir)
    transformer = model_class.from_pretrained(args.model_name_or_path,
                                              config=transformer_config,
                                              cache_dir=cache_dir)
    img_encoder = ImageEncoder(args)
    config = MMBTConfig(transformer_config, num_labels=num_labels)
    model = MMBTForClassification(config, transformer, img_encoder)

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Default, unweighted criterion so that an eval-only run (--do_eval without
    # --do_train) has a loss to report; training replaces it with a
    # frequency-weighted one below.
    criterion = nn.BCEWithLogitsLoss()

    # Training
    if args.do_train:
        train_dataset = load_examples(args, tokenizer, evaluate=False)
        label_counts = train_dataset.get_label_frequencies()
        label_frequencies = [label_counts[l] for l in labels]
        # Positive-class weight per label = inverse of its relative frequency in the training set.
        label_weights = (torch.tensor(label_frequencies, device=args.device, dtype=torch.float) / len(train_dataset)) ** -1
        criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save the trained weights and the tokenizer so they can be reloaded below.
        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
        torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME))
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = MMBTForClassification(config, transformer, img_encoder)
        model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
            model = MMBTForClassification(config, transformer, img_encoder)
            # Each entry of `checkpoints` is a directory; the weights live in WEIGHTS_NAME inside it.
            model.load_state_dict(torch.load(os.path.join(checkpoint, WEIGHTS_NAME)))
            model.to(args.device)
            result = evaluate(args, model, tokenizer, criterion, prefix=prefix)
            result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
            results.update(result)

    return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,130 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
# Number of image embeddings -> output size handed to nn.AdaptiveAvgPool2d in
# ImageEncoder. Each pair multiplies out to its key.
POOLING_BREAKDOWN = {
    1: (1, 1), 2: (2, 1), 3: (3, 1),
    4: (2, 2), 5: (5, 1), 6: (3, 2),
    7: (7, 1), 8: (4, 2), 9: (3, 3),
}
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
    """Image encoder built from a pretrained torchvision ResNet-152.

    The last two child modules of the ResNet are dropped and the remaining
    feature map is adaptively average-pooled to the size that
    ``POOLING_BREAKDOWN`` gives for ``args.num_image_embeds``.
    """

    def __init__(self, args):
        super(ImageEncoder, self).__init__()
        resnet = torchvision.models.resnet152(pretrained=True)
        backbone_layers = list(resnet.children())[:-2]
        self.model = nn.Sequential(*backbone_layers)
        pooled_size = POOLING_BREAKDOWN[args.num_image_embeds]
        self.pool = nn.AdaptiveAvgPool2d(pooled_size)

    def forward(self, x):
        """Encode an image batch: Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048."""
        pooled = self.pool(self.model(x))
        # Collapse the spatial dimensions into one axis, then move it before the channels.
        flat = torch.flatten(pooled, start_dim=2)
        return flat.transpose(1, 2).contiguous()  # BxNx2048
|
||||
|
||||
|
||||
|
||||
class JsonlDataset(Dataset):
    """Dataset backed by a JSON-lines file with ``"text"``, ``"label"`` and ``"img"`` fields.

    Image paths in the ``"img"`` field are resolved relative to the directory
    containing the JSON-lines file.
    """

    def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
        """Load every record of ``data_path`` into memory.

        Args:
            data_path: Path to the ``.jsonl`` file.
            tokenizer: Object whose ``encode(text, add_special_tokens=True)`` is used in ``__getitem__``.
            transforms: Callable applied to each loaded RGB image.
            labels: List of label names; its order defines the label vector layout.
            max_seq_length: Maximum number of text tokens kept per example.
        """
        # Read inside a context manager so the file handle is closed; blank
        # lines (e.g. a trailing newline) are skipped rather than fed to json.loads.
        with open(data_path) as jsonl_file:
            self.data = [json.loads(line) for line in jsonl_file if line.strip()]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.labels = labels
        self.n_classes = len(labels)
        self.max_seq_length = max_seq_length

        self.transforms = transforms

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return one example as a dict of tensors.

        Keys: ``"image_start_token"`` / ``"image_end_token"`` (first and last
        id of the encoded text), ``"sentence"`` (the ids in between, truncated
        to ``max_seq_length``), ``"image"`` (transformed RGB image) and
        ``"label"`` (multi-hot float vector over ``self.labels``).
        """
        row = self.data[index]
        sentence = torch.LongTensor(self.tokenizer.encode(row["text"], add_special_tokens=True))
        # The tokenizer's first/last ids are split off and returned separately.
        start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
        sentence = sentence[:self.max_seq_length]

        label = torch.zeros(self.n_classes)
        label[[self.labels.index(tgt) for tgt in row["label"]]] = 1

        image = Image.open(os.path.join(self.data_dir, row["img"])).convert("RGB")
        image = self.transforms(image)

        return {"image_start_token": start_token, "image_end_token": end_token,
                "sentence": sentence, "image": image, "label": label}

    def get_label_frequencies(self):
        """Return a ``Counter`` mapping each label name to its number of occurrences."""
        label_freqs = Counter()
        for row in self.data:
            label_freqs.update(row["label"])
        return label_freqs
|
||||
|
||||
|
||||
def collate_fn(batch):
    """Merge a list of example dicts into padded batch tensors.

    Sentences are right-padded with zeros to the longest sentence in the
    batch; the mask holds 1 at real token positions and 0 at padding.

    Args:
        batch: list of dicts with the keys ``"sentence"``, ``"image"``,
            ``"label"``, ``"image_start_token"`` and ``"image_end_token"``.

    Returns:
        Tuple ``(text_tensor, mask_tensor, img_tensor, img_start_token,
        img_end_token, tgt_tensor)``.
    """
    sentences = [example["sentence"] for example in batch]
    seq_lengths = [len(sentence) for sentence in sentences]
    batch_size = len(batch)
    longest = max(seq_lengths)

    text_tensor = torch.zeros(batch_size, longest, dtype=torch.long)
    mask_tensor = torch.zeros(batch_size, longest, dtype=torch.long)
    for row_idx, sentence in enumerate(sentences):
        n_tokens = seq_lengths[row_idx]
        text_tensor[row_idx, :n_tokens] = sentence
        mask_tensor[row_idx, :n_tokens] = 1

    def _stack(key):
        # Stack the per-example tensors stored under `key` along a new dim 0.
        return torch.stack([example[key] for example in batch])

    img_tensor = _stack("image")
    tgt_tensor = _stack("label")
    img_start_token = _stack("image_start_token")
    img_end_token = _stack("image_end_token")

    return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor
|
||||
|
||||
|
||||
def get_mmimdb_labels():
    """Return a new list holding the 23 genre label names, in fixed order."""
    genre_names = [
        'Crime',
        'Drama',
        'Thriller',
        'Action',
        'Comedy',
        'Romance',
        'Documentary',
        'Short',
        'Mystery',
        'History',
        'Family',
        'Adventure',
        'Fantasy',
        'Sci-Fi',
        'Western',
        'Horror',
        'Sport',
        'War',
        'Music',
        'Musical',
        'Animation',
        'Biography',
        'Film-Noir',
    ]
    return genre_names
|
||||
|
||||
|
||||
def get_image_transforms():
    """Build the image preprocessing pipeline.

    The pipeline resizes to 256, center-crops to 224, converts to a tensor and
    normalizes each channel with the fixed mean / std values below.
    """
    normalize = transforms.Normalize(
        mean=[0.46777044, 0.44531429, 0.40661017],
        std=[0.12221994, 0.12145835, 0.14380469],
    )
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
|
|
@ -71,6 +71,7 @@ from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE
|
|||
from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from .configuration_mmbt import MMBTConfig
|
||||
|
||||
# Modeling
|
||||
if is_torch_available():
|
||||
|
@ -120,13 +121,12 @@ if is_torch_available():
|
|||
from .modeling_t5 import (T5PreTrainedModel, T5Model, T5WithLMHeadModel,
|
||||
load_tf_weights_in_t5,
|
||||
T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_albert import (AlbertPreTrainedModel, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification,
|
||||
AlbertForQuestionAnswering,
|
||||
load_tf_weights_in_albert, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
from .modeling_xlm_roberta import (XLMRobertaForMaskedLM, XLMRobertaModel, XLMRobertaForMultipleChoice,
|
||||
XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification)
|
||||
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
|
||||
|
||||
# Optimization
|
||||
from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup,
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" MMBT configuration """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MMBTConfig(object):
    """Configuration class to store the configuration of a `MMBT Model`.

    Args:
        config: config of the underlying Transformer models. Its attribute
            dictionary is adopted directly (shared, not copied), so attributes
            set here are also visible on the wrapped ``config`` object.
        num_labels: Size of final Linear layer for classification. Only
            applied when truthy; otherwise any ``num_labels`` already present
            on ``config`` is left untouched.
        modal_hidden_size: Embedding dimension of the non-text modality encoder.
    """

    def __init__(self, config, num_labels=None, modal_hidden_size=2048):
        # Share the wrapped config's attribute dict so a single object carries
        # both the transformer settings and the MMBT-specific ones.
        self.__dict__ = config.__dict__
        self.modal_hidden_size = modal_hidden_size
        if not num_labels:
            return
        self.num_labels = num_labels
|
|
@ -0,0 +1,368 @@
|
|||
# coding=utf-8
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# Copyright (c) HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch MMBT model. """
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModalEmbeddings(nn.Module):
    """Generic Modal Embeddings which takes in an encoder, and a transformer embedding.

    The encoder output is linearly projected from ``config.modal_hidden_size``
    to ``config.hidden_size``. The position, token-type and word embedding
    tables and the LayerNorm are the ones held by the given ``embeddings``
    object (the same module instances, not copies).
    """

    def __init__(self, config, encoder, embeddings):
        super(ModalEmbeddings, self).__init__()
        self.config = config
        self.encoder = encoder
        self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size)
        self.position_embeddings = embeddings.position_embeddings
        self.token_type_embeddings = embeddings.token_type_embeddings
        self.word_embeddings = embeddings.word_embeddings
        self.LayerNorm = embeddings.LayerNorm
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None):
        """Embed the non-text modality as a token sequence.

        Args:
            input_modal: batch of modality inputs fed to the encoder.
            start_token: optional token ids whose word embedding is prepended.
            end_token: optional token ids whose word embedding is appended.
            position_ids: optional position indices; defaults to
                ``0 .. seq_length - 1`` for every batch element.
            token_type_ids: optional segment indices; defaults to all zeros.

        Returns:
            The sum of token, position and token-type embeddings after
            LayerNorm and dropout.
        """
        batch_size = input_modal.size(0)
        device = input_modal.device

        # Gather the sequence pieces in order: [start] + projected modal + [end].
        pieces = [self.proj_embeddings(self.encoder(input_modal))]
        seq_length = pieces[0].size(1)

        if start_token is not None:
            pieces.insert(0, self.word_embeddings(start_token).unsqueeze(1))
            seq_length += 1

        if end_token is not None:
            pieces.append(self.word_embeddings(end_token).unsqueeze(1))
            seq_length += 1

        token_embeddings = torch.cat(pieces, dim=1) if len(pieces) > 1 else pieces[0]

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)

        if token_type_ids is None:
            token_type_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)

        summed = (token_embeddings
                  + self.position_embeddings(position_ids)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))
|
||||
|
||||
|
||||
# Shared introduction prepended to the docstrings of the MMBT model classes via
# `add_start_docstrings`. The paper link target previously pointed at an
# unrelated repository (salesforce/ctrl); it now points at the MMBT paper.
MMBT_START_DOCSTRING = r"""    MMBT model was proposed in
    `Supervised Multimodal Bitransformers for Classifying Images and Text`_
    by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
    It's a supervised multimodal bitransformer model that fuses information from text and other image encoders,
    and obtain state-of-the-art performance on various multimodal classification benchmark tasks.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`Supervised Multimodal Bitransformers for Classifying Images and Text`:
        https://arxiv.org/pdf/1909.02950.pdf

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.MMBTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
        transformer (:class: `~nn.Module`): A text transformer that is used by MMBT.
            It should have embeddings, encoder, and pooler attributes.
        encoder (:class: `~nn.Module`): Encoder for the second modality.
            It should take in a batch of modal inputs and return k, n dimension embeddings.
"""
|
||||
|
||||
MMBT_INPUTS_DOCSTRING = r""" Inputs:
|
||||
**input_modal**: ``torch.FloatTensor`` of shape ``(batch_size, ***)``:
|
||||
The other modality data. It will be the shape that the encoder for that type expects.
|
||||
e.g. With an Image Encoder, the shape would be (batch_size, channels, height, width)
|
||||
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
It does not expect [CLS] token to be added as it's appended to the end of other modality embeddings.
|
||||
See :func:`transformers.PreTrainedTokenizer.encode` and
|
||||
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
|
||||
**modal_start_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for Classification tasks.
|
||||
**modal_end_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used.
|
||||
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Segment token indices to indicate different portions of the inputs.
|
||||
**modal_token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
|
||||
Segment token indices to indicate different portions of the non-text modality.
|
||||
The embeddings from these tokens will be summed with the respective token embeddings for the non-text modality.
|
||||
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
**modal_position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``:
|
||||
Indices of positions of each input sequence tokens in the position embeddings for the non-text modality.
|
||||
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
|
||||
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
|
||||
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
**encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model
|
||||
is configured as a decoder.
|
||||
**encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
||||
is used in the cross-attention if the model is configured as a decoder.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
"""
|
||||
|
||||
@add_start_docstrings("The bare MMBT Model outputting raw hidden-states without any specific head on top.",
                      MMBT_START_DOCSTRING, MMBT_INPUTS_DOCSTRING)
class MMBTModel(nn.Module):
    r"""
        Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
            **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
                Sequence of hidden-states at the output of the last layer of the model.
            **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
                Last layer hidden-state of the first token of the sequence (classification token)
                further processed by a Linear layer and a Tanh activation function. The Linear
                layer weights are trained from the next sentence prediction (classification)
                objective during Bert pretraining. This output is usually *not* a good summary
                of the semantic content of the input, you're often better with averaging or pooling
                the sequence of hidden-states for the whole input sequence.
            **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
                list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
                of shape ``(batch_size, sequence_length, hidden_size)``:
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            **attentions**: (`optional`, returned when ``config.output_attentions=True``)
                list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::
            transformer = BertModel.from_pretrained('bert-base-uncased')
            encoder = ImageEncoder(args)
            mmbt = MMBTModel(config, transformer, encoder)
    """
    def __init__(self, config, transformer, encoder):
        super(MMBTModel, self).__init__()
        self.config = config
        self.transformer = transformer
        # The modal embeddings reuse the text transformer's embedding tables.
        self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)

    def forward(self, input_modal, input_ids=None, modal_start_tokens=None,
                modal_end_tokens=None, attention_mask=None,
                token_type_ids=None, modal_token_type_ids=None,
                position_ids=None, modal_position_ids=None, head_mask=None,
                inputs_embeds=None, encoder_hidden_states=None,
                encoder_attention_mask=None):
        """Run the transformer over the modal embeddings followed by the text embeddings.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or if a mask is neither 2D nor 3D.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_txt_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_txt_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        modal_embeddings = self.modal_encoder(input_modal,
                                              start_token=modal_start_tokens,
                                              end_token=modal_end_tokens,
                                              position_ids=modal_position_ids,
                                              token_type_ids=modal_token_type_ids)

        input_modal_shape = modal_embeddings.size()[:-1]

        # Text tokens default to segment 1 (the modal embeddings default to segment 0).
        if token_type_ids is None:
            token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device)

        txt_embeddings = self.transformer.embeddings(input_ids=input_ids,
                                                     position_ids=position_ids,
                                                     token_type_ids=token_type_ids,
                                                     inputs_embeds=inputs_embeds)

        # Final sequence layout: [modal embeddings ; text embeddings].
        embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)

        input_shape = embedding_output.size()[:-1]

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        else:
            # Modal positions are always attended to. Build the prefix with the
            # caller's mask dtype so torch.cat does not rely on type promotion.
            modal_mask = torch.ones(input_modal_shape, device=device, dtype=attention_mask.dtype)
            attention_mask = torch.cat([modal_mask, attention_mask], dim=1)

        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(input_shape, device=device)
        else:
            encoder_modal_mask = torch.ones(input_modal_shape, device=device, dtype=encoder_attention_mask.dtype)
            encoder_attention_mask = torch.cat([encoder_modal_mask, encoder_attention_mask], dim=1)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        # Provided a padding mask of dimensions [batch_size, seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
        elif attention_mask.dim() == 2:
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            # Fail with a clear message instead of an unbound-variable error below.
            raise ValueError("attention_mask must be 2D or 3D, got {} dimensions".format(attention_mask.dim()))

        # Dtype of the model parameters, used for fp16 compatibility below.
        param_dtype = next(self.parameters()).dtype

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=param_dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "encoder_attention_mask must be 2D or 3D, got {} dimensions".format(encoder_attention_mask.dim()))

        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=param_dtype)  # fp16 compatibility
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=param_dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.transformer.encoder(embedding_output,
                                                   attention_mask=extended_attention_mask,
                                                   head_mask=head_mask,
                                                   encoder_hidden_states=encoder_hidden_states,
                                                   encoder_attention_mask=encoder_extended_attention_mask)

        sequence_output = encoder_outputs[0]
        pooled_output = self.transformer.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

    def get_input_embeddings(self):
        """Return the word embedding module of the wrapped text transformer."""
        # This class has no `embeddings` attribute of its own; the table lives
        # on the wrapped transformer.
        return self.transformer.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding module used for text and modal start/end tokens."""
        self.transformer.embeddings.word_embeddings = value
        # ModalEmbeddings keeps its own reference to the word embeddings; keep
        # it pointing at the same module.
        self.modal_encoder.word_embeddings = value
|
||||
|
||||
|
||||
@add_start_docstrings("""MMBT Model with a sequence classification/regression head on top (a linear layer on top of
                      the pooled output)""", MMBT_START_DOCSTRING, MMBT_INPUTS_DOCSTRING)
class MMBTForClassification(nn.Module):
    r"""
            **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
                Labels for computing the sequence classification/regression loss.
                Indices should be in ``[0, ..., config.num_labels - 1]``.
                If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
                If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

        Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
            **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
                Classification (or regression if config.num_labels==1) loss.
            **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
                list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
                of shape ``(batch_size, sequence_length, hidden_size)``:
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            **attentions**: (`optional`, returned when ``config.output_attentions=True``)
                list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            transformer = BertModel.from_pretrained('bert-base-uncased')
            encoder = ImageEncoder(args)
            model = MMBTForClassification(config, transformer, encoder)
            outputs = model(input_modal, input_ids, labels=labels)
            loss, logits = outputs[:2]
    """

    def __init__(self, config, transformer, encoder):
        super(MMBTForClassification, self).__init__()
        self.num_labels = config.num_labels

        # Backbone, then dropout and a linear head applied to the pooled output.
        self.mmbt = MMBTModel(config, transformer, encoder)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_modal, input_ids=None, modal_start_tokens=None, modal_end_tokens=None,
                attention_mask=None, token_type_ids=None, modal_token_type_ids=None, position_ids=None,
                modal_position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        """Classify the fused modal + text input; also return the loss when ``labels`` is given."""
        backbone_outputs = self.mmbt(input_modal=input_modal, input_ids=input_ids,
                                     modal_start_tokens=modal_start_tokens,
                                     modal_end_tokens=modal_end_tokens,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids,
                                     modal_token_type_ids=modal_token_type_ids,
                                     position_ids=position_ids,
                                     modal_position_ids=modal_position_ids,
                                     head_mask=head_mask,
                                     inputs_embeds=inputs_embeds)

        # Index 1 of the backbone outputs is the pooled output.
        logits = self.classifier(self.dropout(backbone_outputs[1]))

        # Append hidden states and attentions if the backbone returned them.
        result = (logits,) + backbone_outputs[2:]
        if labels is None:
            return result  # logits, (hidden_states), (attentions)

        if self.num_labels == 1:
            # A single output unit means regression.
            loss = MSELoss()(logits.view(-1), labels.view(-1))
        else:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return (loss,) + result  # loss, logits, (hidden_states), (attentions)
|
Загрузка…
Ссылка в новой задаче