from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from __future__ import print_function import torch from torch.utils.data import (SequentialSampler) import numpy as np import random import os from metrics import compute_metrics import pickle import logging import time import argparse from modules.tokenization import BertTokenizer from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from modules.modeling import UniVL from modules.optimization import BertAdam from torch.utils.data import DataLoader from util import parallel_apply, get_logger from dataloader_youcook_retrieval import Youcook_DataLoader from dataloader_msrvtt_retrieval import MSRVTT_DataLoader from dataloader_msrvtt_retrieval import MSRVTT_TrainDataLoader torch.distributed.init_process_group(backend="nccl") global logger def get_args(description='UniVL on Retrieval Task'): parser = argparse.ArgumentParser(description=description) parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument('--train_csv', type=str, default='data/youcookii_singlef_train.csv', help='') parser.add_argument('--val_csv', type=str, default='data/youcookii_singlef_val.csv', help='') parser.add_argument('--data_path', type=str, default='data/youcookii_caption.pickle', help='data pickle file path') parser.add_argument('--features_path', type=str, default='data/youcookii_videos_feature.pickle', help='feature path') parser.add_argument('--num_thread_reader', type=int, default=1, help='') parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate') parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit') parser.add_argument('--batch_size', type=int, default=256, help='batch size') parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval') parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay') parser.add_argument('--n_display', type=int, default=100, help='Information display frequence') parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension') parser.add_argument('--seed', type=int, default=42, help='random seed') parser.add_argument('--max_words', type=int, default=20, help='') parser.add_argument('--max_frames', type=int, default=100, help='') parser.add_argument('--feature_framerate', type=int, default=1, help='') parser.add_argument('--margin', type=float, default=0.1, help='margin for loss') parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample') parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative') parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader') parser.add_argument("--output_dir", default=None, type=str, required=True, help="The output directory where the model predictions and checkpoints will be written.") parser.add_argument("--bert_model", default="bert-base-uncased", type=str, required=True, help="Bert pre-trained model") parser.add_argument("--visual_model", default="visual-base", type=str, required=False, help="Visual module") parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module") parser.add_argument("--decoder_model", default="decoder-base", type=str, required=False, help="Decoder module") parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.") parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") parser.add_argument("--warmup_proportion", default=0.1, type=float, help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.") parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.") parser.add_argument("--cache_dir", default="", type=str, help="Where do you want to store the pre-trained models downloaded from s3") parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") parser.add_argument('--fp16_opt_level', type=str, default='O1', help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html") parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.") parser.add_argument("--datatype", default="youcook", type=str, help="Point the dataset `youcook` to finetune.") parser.add_argument("--world_size", default=0, type=int, help="distribted training") parser.add_argument("--local_rank", default=0, type=int, help="distribted training") parser.add_argument('--coef_lr', type=float, default=0.1, help='coefficient for bert branch.') parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).") parser.add_argument('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.") parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.") parser.add_argument('--visual_num_hidden_layers', type=int, default=6, help="Layer NO. of visual.") parser.add_argument('--cross_num_hidden_layers', type=int, default=2, help="Layer NO. of cross.") parser.add_argument('--decoder_num_hidden_layers', type=int, default=3, help="Layer NO. of decoder.") parser.add_argument('--train_sim_after_cross', action='store_true', help="Test retrieval after cross encoder.") parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="") args = parser.parse_args() # Check paramenters if args.gradient_accumulation_steps < 1: raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( args.gradient_accumulation_steps)) if not args.do_train and not args.do_eval: raise ValueError("At least one of `do_train` or `do_eval` must be True.") args.batch_size = int(args.batch_size / args.gradient_accumulation_steps) return args def set_seed_logger(args): global logger # predefining random initial seeds random.seed(args.seed) os.environ['PYTHONHASHSEED'] = str(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU. torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True world_size = torch.distributed.get_world_size() torch.cuda.set_device(args.local_rank) args.world_size = world_size if not os.path.exists(args.output_dir): os.makedirs(args.output_dir, exist_ok=True) logger = get_logger(os.path.join(args.output_dir, "log.txt")) if args.local_rank == 0: logger.info("Effective parameters:") for key in sorted(args.__dict__): logger.info(" <<< {}: {}".format(key, args.__dict__[key])) return args def init_device(args, local_rank): global logger device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank) n_gpu = torch.cuda.device_count() logger.info("device: {} n_gpu: {}".format(device, n_gpu)) args.n_gpu = n_gpu if args.batch_size % args.n_gpu != 0 or args.batch_size_val % args.n_gpu != 0: raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format( args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu)) return device, n_gpu def init_model(args, device, n_gpu, local_rank): if args.init_model: model_state_dict = torch.load(args.init_model, map_location='cpu') else: model_state_dict = None # Prepare model cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed') model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args) model.to(device) return model def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.): if hasattr(model, 'module'): model = model.module param_optimizer = list(model.named_parameters()) no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] no_decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)] decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)] no_decay_bert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." in n] no_decay_nobert_param_tp = [(n, p) for n, p in no_decay_param_tp if "bert." not in n] decay_bert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." in n] decay_nobert_param_tp = [(n, p) for n, p in decay_param_tp if "bert." not in n] optimizer_grouped_parameters = [ {'params': [p for n, p in no_decay_bert_param_tp], 'weight_decay': 0.01, 'lr': args.lr * coef_lr}, {'params': [p for n, p in no_decay_nobert_param_tp], 'weight_decay': 0.01}, {'params': [p for n, p in decay_bert_param_tp], 'weight_decay': 0.0, 'lr': args.lr * coef_lr}, {'params': [p for n, p in decay_nobert_param_tp], 'weight_decay': 0.0} ] scheduler = None optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion, schedule='warmup_linear', t_total=num_train_optimization_steps, weight_decay=0.01, max_grad_norm=1.0) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) return optimizer, scheduler, model def dataloader_youcook_train(args, tokenizer): youcook_dataset = Youcook_DataLoader( csv=args.train_csv, data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, ) train_sampler = torch.utils.data.distributed.DistributedSampler(youcook_dataset) dataloader = DataLoader( youcook_dataset, batch_size=args.batch_size // args.n_gpu, num_workers=args.num_thread_reader, pin_memory=False, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True, ) return dataloader, len(youcook_dataset), train_sampler def dataloader_youcook_test(args, tokenizer): youcook_testset = Youcook_DataLoader( csv=args.val_csv, data_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, ) test_sampler = SequentialSampler(youcook_testset) dataloader_youcook = DataLoader( youcook_testset, sampler=test_sampler, batch_size=args.batch_size_val, num_workers=args.num_thread_reader, pin_memory=False, ) logger.info('YoucookII validation pairs: {}'.format(len(youcook_testset))) return dataloader_youcook, len(youcook_testset) def dataloader_msrvtt_train(args, tokenizer): msrvtt_dataset = MSRVTT_TrainDataLoader( csv_path=args.train_csv, json_path=args.data_path, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, unfold_sentences=args.expand_msrvtt_sentences, ) train_sampler = torch.utils.data.distributed.DistributedSampler(msrvtt_dataset) dataloader = DataLoader( msrvtt_dataset, batch_size=args.batch_size // args.n_gpu, num_workers=args.num_thread_reader, pin_memory=False, shuffle=(train_sampler is None), sampler=train_sampler, drop_last=True, ) return dataloader, len(msrvtt_dataset), train_sampler def dataloader_msrvtt_test(args, tokenizer): msrvtt_testset = MSRVTT_DataLoader( csv_path=args.val_csv, features_path=args.features_path, max_words=args.max_words, feature_framerate=args.feature_framerate, tokenizer=tokenizer, max_frames=args.max_frames, ) dataloader_msrvtt = DataLoader( msrvtt_testset, batch_size=args.batch_size_val, num_workers=args.num_thread_reader, shuffle=False, drop_last=False, ) return dataloader_msrvtt, len(msrvtt_testset) def save_model(epoch, args, model, type_name=""): # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model output_model_file = os.path.join( args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name=="" else type_name+".", epoch)) torch.save(model_to_save.state_dict(), output_model_file) logger.info("Model saved to %s", output_model_file) return output_model_file def load_model(epoch, args, n_gpu, device, model_file=None): if model_file is None or len(model_file) == 0: model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch)) if os.path.exists(model_file): model_state_dict = torch.load(model_file, map_location='cpu') if args.local_rank == 0: logger.info("Model loaded from %s", model_file) # Prepare model cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed') model = UniVL.from_pretrained(args.bert_model, args.visual_model, args.cross_model, args.decoder_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args) model.to(device) else: model = None return model def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0): global logger torch.cuda.empty_cache() model.train() log_step = args.n_display start_time = time.time() total_loss = 0 for step, batch in enumerate(train_dataloader): if n_gpu == 1: # multi-gpu does scattering it-self batch = tuple(t.to(device=device, non_blocking=True) for t in batch) input_ids, input_mask, segment_ids, video, video_mask, \ pairs_masked_text, pairs_token_labels, masked_video, video_labels_index = batch loss = model(input_ids, segment_ids, input_mask, video, video_mask, pairs_masked_text=pairs_masked_text, pairs_token_labels=pairs_token_labels, masked_video=masked_video, video_labels_index=video_labels_index) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps loss.backward() total_loss += float(loss) if (step + 1) % args.gradient_accumulation_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) if scheduler is not None: scheduler.step() # Update learning rate schedule optimizer.step() optimizer.zero_grad() global_step += 1 if global_step % log_step == 0 and local_rank == 0: logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1, args.epochs, step + 1, len(train_dataloader), "-".join([str('%.6f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]), float(loss) * args.gradient_accumulation_steps, (time.time() - start_time) / (log_step * args.gradient_accumulation_steps)) start_time = time.time() total_loss = total_loss / len(train_dataloader) return total_loss, global_step def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list): sim_matrix = [] for idx1, b1 in enumerate(batch_list_t): input_ids, input_mask, segment_ids, _, _, _, _, _, _ = b1 sequence_output = batch_sequence_output_list[idx1] each_row = [] for idx2, b2 in enumerate(batch_list_v): _, _, _, video, video_mask, _, _, _, _ = b2 visual_output = batch_visual_output_list[idx2] b1b2_logits = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask) b1b2_logits = b1b2_logits.cpu().detach().numpy() each_row.append(b1b2_logits) each_row = np.concatenate(tuple(each_row), axis=-1) sim_matrix.append(each_row) return sim_matrix def eval_epoch(args, model, test_dataloader, device, n_gpu): if hasattr(model, 'module'): model = model.module.to(device) else: model = model.to(device) model.eval() with torch.no_grad(): batch_list = [] batch_sequence_output_list, batch_visual_output_list = [], [] for bid, batch in enumerate(test_dataloader): batch = tuple(t.to(device) for t in batch) input_ids, input_mask, segment_ids, video, video_mask, _, _, _, _ = batch sequence_output, visual_output = model.get_sequence_visual_output(input_ids, segment_ids, input_mask, video, video_mask) batch_sequence_output_list.append(sequence_output) batch_visual_output_list.append(visual_output) batch_list.append(batch) print("{}/{}\r".format(bid, len(test_dataloader)), end="") if n_gpu > 1: device_ids = list(range(n_gpu)) batch_list_t_splits = [] batch_list_v_splits = [] batch_t_output_splits = [] batch_v_output_splits = [] bacth_len = len(batch_list) split_len = (bacth_len + n_gpu - 1) // n_gpu for dev_id in device_ids: s_, e_ = dev_id * split_len, (dev_id + 1) * split_len if dev_id == 0: batch_list_t_splits.append(batch_list[s_:e_]) batch_list_v_splits.append(batch_list) batch_t_output_splits.append(batch_sequence_output_list[s_:e_]) batch_v_output_splits.append(batch_visual_output_list) else: devc = torch.device('cuda:{}'.format(str(dev_id))) devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list[s_:e_]] batch_list_t_splits.append(devc_batch_list) devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list] batch_list_v_splits.append(devc_batch_list) devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]] batch_t_output_splits.append(devc_batch_list) devc_batch_list = [b.to(devc) for b in batch_visual_output_list] batch_v_output_splits.append(devc_batch_list) parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id], batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids] parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids) sim_matrix = [] for idx in range(len(parallel_outputs)): sim_matrix += parallel_outputs[idx] sim_matrix = np.concatenate(tuple(sim_matrix), axis=0) else: sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list) metrics = compute_metrics(sim_matrix) logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0]))) logger.info('\t>>> R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'. format(metrics['R1'], metrics['R5'], metrics['R10'], metrics['MR'])) R1 = metrics['R1'] return R1 DATALOADER_DICT = {} DATALOADER_DICT["youcook"] = {"train":dataloader_youcook_train, "val":dataloader_youcook_test} DATALOADER_DICT["msrvtt"] = {"train":dataloader_msrvtt_train, "val":dataloader_msrvtt_test} def main(): global logger args = get_args() args = set_seed_logger(args) device, n_gpu = init_device(args, args.local_rank) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) assert args.task_type == "retrieval" model = init_model(args, device, n_gpu, args.local_rank) assert args.datatype in DATALOADER_DICT test_dataloader, test_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer) if args.local_rank == 0: logger.info("***** Running test *****") logger.info(" Num examples = %d", test_length) logger.info(" Batch size = %d", args.batch_size_val) logger.info(" Num steps = %d", len(test_dataloader)) if args.do_train: train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer) num_train_optimization_steps = (int(len(train_dataloader) + args.gradient_accumulation_steps - 1) / args.gradient_accumulation_steps) * args.epochs coef_lr = args.coef_lr if args.init_model: coef_lr = 1.0 optimizer, scheduler, model = prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, args.local_rank, coef_lr=coef_lr) if args.local_rank == 0: logger.info("***** Running training *****") logger.info(" Num examples = %d", train_length) logger.info(" Batch size = %d", args.batch_size) logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps) best_score = 0.00001 best_output_model_file = None global_step = 0 for epoch in range(args.epochs): train_sampler.set_epoch(epoch) tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=args.local_rank) if args.local_rank == 0: logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss) output_model_file = save_model(epoch, args, model, type_name="") R1 = eval_epoch(args, model, test_dataloader, device, n_gpu) if best_score <= R1: best_score = R1 best_output_model_file = output_model_file logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score)) if args.local_rank == 0: model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file) eval_epoch(args, model, test_dataloader, device, n_gpu) elif args.do_eval: if args.local_rank == 0: eval_epoch(args, model, test_dataloader, device, n_gpu) if __name__ == "__main__": main()