re organize
This commit is contained in:
Parent 0a1b59cb95
Commit b9da891396

@@ -0,0 +1,230 @@
import argparse
|
||||
import os
|
||||
from transformers import set_seed
|
||||
from diffusion_util.resample import create_named_schedule_sampler
|
||||
from transformers import AutoTokenizer
|
||||
import json
|
||||
from util import logger
|
||||
from train_util import dist_util
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from util.util import (
|
||||
create_model_and_diffusion,
|
||||
args_to_dict,
|
||||
)
|
||||
import collections
|
||||
from data_util.s2s_data_util import load_s2s_data
|
||||
from train_util.train_util import TrainLoop
|
||||
from torch.serialization import default_restore_location
|
||||
|
||||
import os
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
CheckpointState = collections.namedtuple("CheckpointState",
|
||||
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
|
||||
|
||||
|
||||
def get_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# out path
|
||||
parser.add_argument('--checkpoint_path', type=str, default='', help='output path')
|
||||
|
||||
# load pretrain
|
||||
    parser.add_argument('--pretrain_model_path', type=str, default=None, help='path of a pre-trained diffusion checkpoint to initialize from')
|
||||
|
||||
# load model
|
||||
parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
|
||||
parser.add_argument('--model_channels', type=int, default=768, help='Try to set it to the same size as the model hidden')
|
||||
    parser.add_argument('--in_channel', type=int, default=768, help='The input channel size here must be the same as the word embedding size')
|
||||
    parser.add_argument('--out_channel', type=int, default=768, help='The output dimension size; recommended to match the word embedding size for easier inference')
|
||||
parser.add_argument('--dropout', type=float, default=0.1, help='')
|
||||
parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
|
||||
parser.add_argument('--logits_mode', type=int, default=1, help='final logits mode of Diffusion model')
|
||||
parser.add_argument('--vocab_size', type=int, default=30522, help='vocab size')
|
||||
parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='')
|
||||
parser.add_argument('--token_emb_type', type=str, default='random', help='token embedding type')
|
||||
parser.add_argument("--init_pretrained", default=False, action="store_true", help="Whether to using pretrain BERT encoder")
|
||||
parser.add_argument("--fix_encoder", default=False, action="store_true",
|
||||
help="Whether to training encoder")
|
||||
|
||||
|
||||
# load diffusion
|
||||
parser.add_argument('--diffusion_steps', type=int, default=2000, help='Diffusion model maximum T')
|
||||
    parser.add_argument('--use_kl', default=False, action="store_true", help="Whether to use the KL loss in the diffusion loss")
|
||||
parser.add_argument('--training_mode', type=str, default='e2e', help='using e2e simple loss or e2e loss or s2s loss')
|
||||
parser.add_argument('--noise_schedule', type=str, default='sqrt', help='How to plan the noise change of Gaussian distribution')
|
||||
parser.add_argument('--predict_xstart', default=False, action="store_true", help="Model prediction target, if True, predict xstart, if False, predict EPSILON")
|
||||
parser.add_argument("--sigma_small", default=False, action="store_true", help="about learning variance")
|
||||
parser.add_argument("--rescale_learned_sigmas", default=True, action="store_false", help="about learning variance")
|
||||
parser.add_argument("--rescale_timesteps", default=True, action="store_false", help="about time rescale")
|
||||
|
||||
# sample t
|
||||
parser.add_argument('--schedule_sampler', type=str, default='uniform', help='how to sample t per batch, uniform is Uniform sampling, loss-second-moment is Sampling according to loss')
|
||||
|
||||
# data args
|
||||
parser.add_argument('--data_path', type=str, default='',help='data path')
|
||||
parser.add_argument('--data_name', type=str, default='', help='data name')
|
||||
# for seq2seq
|
||||
parser.add_argument('--src_max_len', type=int, default=144, help='src max len')
|
||||
parser.add_argument('--tgt_max_len', type=int, default=32, help='tgt max len')
|
||||
    parser.add_argument('--answer_max_len', type=int, default=10, help='answer max len')
|
||||
# for doc2query
|
||||
parser.add_argument('--text_max_len', type=int, default=None, help='text max len')
|
||||
parser.add_argument('--pas_max_len', type=int, default=None, help='pas max len')
|
||||
|
||||
# training args
|
||||
parser.add_argument('--train_type', type=str, default='LM_Diffusion', help='LM_Diffusion or S2S_Diffusion')
|
||||
parser.add_argument('--lr_anneal_steps', type=int, default=200000, help='total step')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='')
|
||||
parser.add_argument('--lr', type=float, default=1e-04, help='')
|
||||
parser.add_argument('--warmup_steps', type=int, default=20000, help='')
|
||||
parser.add_argument('--ema_rate', type=str, default='0.9999', help='ema training to stable model')
|
||||
parser.add_argument('--resume_checkpoint', type=str, default=None, help='')
|
||||
parser.add_argument('--eval_interval', type=int, default=2000, help='')
|
||||
parser.add_argument('--log_interval', type=int, default=100, help='')
|
||||
parser.add_argument('--save_interval', type=int, default=50000, help='')
|
||||
|
||||
    parser.add_argument('--weight_decay', type=float, default=0.0, help='')
|
||||
parser.add_argument('--gradient_clipping', type=float, default=-1., help='')
|
||||
parser.add_argument("--use_fp16", default=False, action="store_true", help="about learning variance")
|
||||
parser.add_argument('--fp16_scale_growth', type=float, default=1e-3, help='')
|
||||
|
||||
# seed
|
||||
parser.add_argument('--seed', type=int, default=101, help='')
|
||||
|
||||
    # multi-gpu
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def setup_env(args):
|
||||
if args.local_rank == -1:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# store args
|
||||
if args.local_rank != -1:
|
||||
args.world_size = torch.distributed.get_world_size()
|
||||
args.rank = dist.get_rank()
|
||||
|
||||
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
|
||||
logger.info('Reading saved model from %s', model_file)
|
||||
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
|
||||
logger.info('model_state_dict keys %s', state_dict.keys())
|
||||
return CheckpointState(**state_dict)
|
||||
|
||||
def main():
|
||||
# args setting
|
||||
args = get_arguments()
|
||||
|
||||
# out dir set
|
||||
if dist.get_rank() == 0:
|
||||
if not os.path.exists(args.checkpoint_path):
|
||||
os.makedirs(args.checkpoint_path)
|
||||
# dist.barrier()
|
||||
|
||||
logger.log(f'saving the hyperparameters to {args.checkpoint_path}/training_args.json')
|
||||
with open(f'{args.checkpoint_path}/training_args.json', 'w') as f:
|
||||
json.dump(args.__dict__, f, indent=2)
|
||||
|
||||
# seed setting
|
||||
set_seed(args.seed)
|
||||
    # ddp setting
|
||||
setup_env(args)
|
||||
# dist_util.setup_dist()
|
||||
|
||||
# logger setting
|
||||
log_path = os.path.join(args.checkpoint_path, 'log.txt')
|
||||
logger.configure(dir=log_path)
|
||||
|
||||
model, diffusion = create_model_and_diffusion(
|
||||
args
|
||||
)
|
||||
if args.pretrain_model_path is not None:
|
||||
print("load model ckpt at :", args.pretrain_model_path)
|
||||
saved_state = load_states_from_checkpoint(args.pretrain_model_path)
|
||||
model.load_state_dict(saved_state.model_dict, strict=False)
|
||||
model.to(args.device)
|
||||
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
||||
logger.log(f'the parameter count is {pytorch_total_params}')
|
||||
|
||||
'''
|
||||
time step schedule sampler
|
||||
'''
|
||||
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
|
||||
|
||||
'''
|
||||
tokenize
|
||||
'''
|
||||
logger.log("loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
||||
|
||||
'''
|
||||
for s2s
|
||||
'''
|
||||
# load data (train)
|
||||
train_data = load_s2s_data(
|
||||
args,
|
||||
split='train',
|
||||
padding_mode='max_len',
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
# load data (dev)
|
||||
dev_data = load_s2s_data(
|
||||
args,
|
||||
split='dev',
|
||||
padding_mode='max_len',
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
'''
|
||||
training
|
||||
'''
|
||||
logger.log("training Diffusion LM model...")
|
||||
TrainLoop(
|
||||
# training type
|
||||
train_type=args.train_type,
|
||||
# Training Core
|
||||
model=model,
|
||||
diffusion=diffusion,
|
||||
data=train_data,
|
||||
eval_data=dev_data,
|
||||
schedule_sampler=schedule_sampler,
|
||||
checkpoint_path=args.checkpoint_path,
|
||||
# Training Parameters
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
ema_rate=args.ema_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
lr_anneal_steps=args.lr_anneal_steps,
|
||||
gradient_clipping=args.gradient_clipping,
|
||||
# fp16
|
||||
use_fp16=args.use_fp16,
|
||||
fp16_scale_growth=args.fp16_scale_growth,
|
||||
# Training Log
|
||||
resume_checkpoint=args.resume_checkpoint,
|
||||
eval_interval=args.eval_interval,
|
||||
log_interval=args.log_interval,
|
||||
save_interval=args.save_interval,
|
||||
# device
|
||||
device=args.device,
|
||||
# finetune data name
|
||||
data_name=args.data_name
|
||||
).run_loop()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()

@@ -0,0 +1,397 @@
import os
|
||||
from util import logger
|
||||
from train_util import dist_util
|
||||
from util.util import (
|
||||
create_model_and_diffusion,
|
||||
args_to_dict,
|
||||
)
|
||||
# from transformers import set_seed
|
||||
import torch
|
||||
import collections
|
||||
import argparse
|
||||
from transformers import AutoTokenizer
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from data_util.s2s_data_util import load_s2s_data
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
||||
from torch.utils.data import DataLoader
|
||||
from data_util.s2s_data_util import S2S_dataset, QG_dataset_Diff
|
||||
from torch.serialization import default_restore_location
|
||||
|
||||
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertConfig,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from data_util.text_data_util import load_data_text
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
|
||||
def get_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# out path
|
||||
parser.add_argument('--generate_path', type=str, default='', help='output path')
|
||||
parser.add_argument('--eval_model_path', type=str, default='', help='model path')
|
||||
parser.add_argument('--num_samples', type=int, default=50, help='sample query')
|
||||
parser.add_argument('--interval_step', type=int, default=1, help='inference t interval step')
|
||||
|
||||
# load model
|
||||
parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
|
||||
parser.add_argument('--model_channels', type=int, default=768,
|
||||
help='Try to set it to the same size as the model hidden')
|
||||
    parser.add_argument('--in_channel', type=int, default=768,
                        help='The input channel size here must be the same as the word embedding size')
|
||||
    parser.add_argument('--out_channel', type=int, default=768,
                        help='The output dimension size; recommended to match the word embedding size for easier inference')
|
||||
parser.add_argument('--dropout', type=float, default=0.1, help='')
|
||||
parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
|
||||
parser.add_argument('--logits_mode', type=int, default=1, help='final logits mode of Diffusion model')
|
||||
parser.add_argument('--vocab_size', type=int, default=30522, help='vocab size')
|
||||
parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='')
|
||||
parser.add_argument('--token_emb_type', type=str, default='random', help='token embedding type')
|
||||
parser.add_argument("--init_pretrained", default=False, action="store_true",
|
||||
help="Whether to using pretrain BERT encoder")
|
||||
|
||||
# load diffusion
|
||||
# parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
|
||||
parser.add_argument('--diffusion_steps', type=int, default=2000, help='Diffusion model maximum T')
|
||||
# parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
|
||||
    parser.add_argument('--use_kl', default=False, action="store_true",
                        help="Whether to use the KL loss in the diffusion loss")
|
||||
parser.add_argument('--training_mode', type=str, default='e2e', help='using e2e simple loss or e2e loss')
|
||||
parser.add_argument('--noise_schedule', type=str, default='sqrt',
|
||||
help='How to plan the noise change of Gaussian distribution')
|
||||
parser.add_argument('--predict_xstart', default=False, action="store_true",
|
||||
help="Model prediction target, if True, predict xstart, if False, predict EPSILON")
|
||||
parser.add_argument("--sigma_small", default=False, action="store_true", help="about learning variance")
|
||||
parser.add_argument("--rescale_learned_sigmas", default=True, action="store_false", help="about learning variance")
|
||||
parser.add_argument("--rescale_timesteps", default=True, action="store_false", help="about time rescale")
|
||||
|
||||
# data args
|
||||
parser.add_argument('--data_path', type=str, default='', help='data path')
|
||||
parser.add_argument('--data_name', type=str, default='', help='data name')
|
||||
# for seq2seq
|
||||
parser.add_argument('--src_max_len', type=int, default=144, help='src max len')
|
||||
parser.add_argument('--tgt_max_len', type=int, default=32, help='tgt max len')
|
||||
    parser.add_argument('--answer_max_len', type=int, default=10, help='answer max len')
|
||||
|
||||
# gen args
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='')
|
||||
|
||||
# seed
|
||||
parser.add_argument('--seed', type=int, default=101, help='')
|
||||
#
|
||||
    # multi-gpu
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
CheckpointState = collections.namedtuple("CheckpointState",
|
||||
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
|
||||
|
||||
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
|
||||
logger.info('Reading saved model from %s', model_file)
|
||||
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
|
||||
logger.info('model_state_dict keys %s', state_dict.keys())
|
||||
return CheckpointState(**state_dict)
|
||||
|
||||
|
||||
'''
|
||||
rounding
|
||||
'''
|
||||
def denoised_fn_round(args, model, text_emb, t):
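    # Rounding trick: map each denoised embedding to its nearest vocabulary embedding
    # (L2 nearest neighbour over model.weight) and re-embed the resulting token ids,
    # so every denoising step stays on the word-embedding manifold.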
|
||||
# thresh_t = 50
|
||||
# # print(thresh_t)
|
||||
# if thresh_t is not None and t[0] > thresh_t:
|
||||
# return text_emb
|
||||
|
||||
if args.model_arch == '1d-unet':
|
||||
text_emb = text_emb.permute(0, 2, 1)
|
||||
# return text_emb
|
||||
# print(t.float().mean(), t[0])
|
||||
|
||||
# assert t.float().mean() == t[0].float()
|
||||
|
||||
# print(text_emb.shape) # bsz, seqlen, dim
|
||||
down_proj_emb = model.weight # input_embs
|
||||
# print(t)
|
||||
old_shape = text_emb.shape
|
||||
old_device = text_emb.device
|
||||
|
||||
def get_efficient_knn(down_proj_emb, text_emb, dist='l2'):
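        # Squared L2 distances computed in one shot as ||e||^2 + ||x||^2 - 2 * e @ x^T
        # over the whole vocabulary; top-1 of the negated distances gives the nearest token.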
|
||||
if dist == 'l2':
|
||||
emb_norm = (down_proj_emb ** 2).sum(-1).view(-1, 1) # vocab
|
||||
text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
|
||||
arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
|
||||
# print(emb_norm.shape, arr_norm.shape)
|
||||
dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(down_proj_emb,
|
||||
text_emb_t) # (vocab, d) x (d, bsz*seqlen)
|
||||
dist = torch.clamp(dist, 0.0, np.inf)
|
||||
# print(dist.shape)
|
||||
topk_out = torch.topk(-dist, k=1, dim=0)
|
||||
# adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
|
||||
# down_proj_emb.size(0), -1, -1)
|
||||
# adjacency = -th.norm(adjacency, dim=-1)
|
||||
# topk_out = th.topk(adjacency, k=1, dim=0)
|
||||
# print(topk_out1.indices == topk_out.indices)
|
||||
# assert th.all(topk_out1.indices == topk_out.indices)
|
||||
return topk_out.values, topk_out.indices
|
||||
|
||||
def get_knn(down_proj_emb, text_emb, dist='l2'):
|
||||
if dist == 'l2':
|
||||
adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
|
||||
down_proj_emb.size(0), -1, -1)
|
||||
adjacency = -torch.norm(adjacency, dim=-1)
|
||||
topk_out = torch.topk(adjacency, k=1, dim=0)
|
||||
return topk_out.values, topk_out.indices
|
||||
|
||||
dist = 'l2'
|
||||
if len(text_emb.shape) > 2:
|
||||
text_emb = text_emb.reshape(-1, text_emb.size(-1))
|
||||
else:
|
||||
text_emb = text_emb
|
||||
# val, indices = get_knn(down_proj_emb,
|
||||
# text_emb.to(down_proj_emb.device), dist=dist)
|
||||
val, indices = get_efficient_knn(down_proj_emb,
|
||||
text_emb.to(down_proj_emb.device), dist=dist)
|
||||
rounded_tokens = indices[0]
|
||||
# print(rounded_tokens.shape)
|
||||
new_embeds = model(rounded_tokens).view(old_shape).to(old_device)
|
||||
if args.model_arch == '1d-unet':
|
||||
new_embeds = new_embeds.permute(0, 2, 1)
|
||||
return new_embeds
|
||||
|
||||
|
||||
|
||||
def setup_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
def setup_env(args):
|
||||
if args.local_rank == -1:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# store args
|
||||
if args.local_rank != -1:
|
||||
args.world_size = torch.distributed.get_world_size()
|
||||
args.rank = dist.get_rank()
|
||||
|
||||
def main():
|
||||
# env setting
|
||||
args = get_arguments()
|
||||
# setup_seed(args.seed)
|
||||
setup_env(args)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
if not os.path.exists(args.generate_path):
|
||||
os.makedirs(args.generate_path)
|
||||
|
||||
log_path = os.path.join(args.generate_path, 'log')
|
||||
logger.configure(dir=log_path)
|
||||
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# define model and diffusion
|
||||
model, diffusion = create_model_and_diffusion(
|
||||
args
|
||||
)
|
||||
model.to(args.device)
|
||||
model.eval()
|
||||
# load trained model
|
||||
model_saved_state = load_states_from_checkpoint(args.eval_model_path)
|
||||
model.load_state_dict(model_saved_state.model_dict)
|
||||
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
||||
logger.log(f'the parameter count is {pytorch_total_params}')
|
||||
|
||||
if dist.get_world_size() > 1:
|
||||
model = DDP(
|
||||
model, device_ids=[dist.get_rank()], output_device=dist.get_rank(), find_unused_parameters=False,
|
||||
)
|
||||
|
||||
logger.log("sampling text from random noise...")
|
||||
print("sample num is :", args.num_samples)
|
||||
print("sample interval step is :", args.interval_step)
|
||||
print("total inverse diffusion step is :", 2000 // args.interval_step)
|
||||
|
||||
sample_fn = (
|
||||
diffusion.p_sample_loop
|
||||
)
|
||||
|
||||
if dist.get_world_size() > 1:
|
||||
emb_model = model.module.word_embedding
|
||||
else:
|
||||
emb_model = model.word_embedding
|
||||
|
||||
if args.model_arch == 'transformer':
|
||||
|
||||
sample_shape = (args.num_samples, args.text_max_len, args.in_channel)
|
||||
sample = sample_fn(
|
||||
model,
|
||||
sample_shape,
|
||||
clip_denoised=False,
|
||||
denoised_fn=partial(denoised_fn_round, args, emb_model.cuda()),
|
||||
model_kwargs=None,
|
||||
top_p=-1.0,
|
||||
)
|
||||
print("sample result shape: ", sample.shape)
|
||||
print('decoding for e2e... ')
|
||||
|
||||
logits = model.get_logits(sample)
|
||||
cands = torch.topk(logits, k=1, dim=-1)
|
||||
sample_id_list = cands.indices
|
||||
print("decode id list example :", type(sample_id_list[0]), " ", sample_id_list[0])
|
||||
|
||||
logger.log("creating tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
||||
for sample_id in sample_id_list:
|
||||
sentence = tokenizer.decode(sample_id.squeeze())
|
||||
print(sentence)
|
||||
|
||||
elif args.model_arch == 's2s_CAT':
|
||||
|
||||
# bert tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
||||
print("-------------------------------------------------------------")
|
||||
print("start generate query from dev dataset, for every passage, we generate ", args.num_samples, " querys...")
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
print("***** load " + args.data_name + " test src dataset*****")
|
||||
src = []
|
||||
test_src_path = os.path.join(args.data_path, args.data_name + "/org_data/test.src")
|
||||
with open(test_src_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
src.append(text)
|
||||
|
||||
print("***** load " + args.data_name + " dev tgt dataset*****")
|
||||
tgt = []
|
||||
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
|
||||
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
tgt.append(text)
|
||||
|
||||
shard_size = len(src) // args.world_size
|
||||
start_idx = args.local_rank * shard_size
|
||||
end_idx = start_idx + shard_size
|
||||
if args.local_rank == args.world_size - 1:
|
||||
end_idx = len(src)
|
||||
scr_data_piece = src[start_idx:end_idx]
|
||||
tgt_data_piece = tgt[start_idx:end_idx]
|
||||
|
||||
print('generation for ', len(scr_data_piece), " src text from idx ", start_idx, " to ", end_idx)
|
||||
if args.data_name == "squadqg_data":
|
||||
test_dataset = QG_dataset_Diff(scr_data_piece, tgt_data_piece, tokenizer, src_maxlength=args.src_max_len,
|
||||
answer_maxlength=args.answer_max_len, tgt_maxlength=args.tgt_max_len)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, drop_last=False,
|
||||
num_workers=20, collate_fn=QG_dataset_Diff.get_collate_fn())
|
||||
else:
|
||||
test_dataset = S2S_dataset(scr_data_piece, tgt_data_piece, tokenizer, src_maxlength=args.src_max_len,
|
||||
tgt_maxlength=args.tgt_max_len)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, drop_last=False,
|
||||
num_workers=20, collate_fn=S2S_dataset.get_collate_fn())
|
||||
|
||||
if args.generate_path is not None:
|
||||
model_gen_files = []
|
||||
if os.path.exists(args.generate_path):
|
||||
for item in os.scandir(args.generate_path):
|
||||
if item.is_file():
|
||||
if "gen_seed" in item.path:
|
||||
model_gen_files.append(item.path)
|
||||
if len(model_gen_files) != 0 :
|
||||
model_gen_files.sort(key=lambda f: int((f.split('_epoch')[-1]).split('.txt')[0]), reverse=True)
|
||||
epoch_num = int((model_gen_files[0].split('_epoch')[-1]).split('.txt')[0])
|
||||
logger.info("***** load " + model_gen_files[0] + " *****")
|
||||
else:
|
||||
epoch_num = 0
|
||||
|
||||
else:
|
||||
logger.info("generate_path is None")
|
||||
exit(0)
|
||||
|
||||
for epoch in range(args.num_samples - epoch_num):
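            # One "epoch" here is one full pass over the test set, producing one generated
            # sequence per input; epoch_num lets generation resume after previously written
            # gen_seed files instead of starting from scratch.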
|
||||
each_sample_list = []
|
||||
print("-------------------------------------------------------------")
|
||||
print("start sample ", epoch+1+epoch_num, " epoch...")
|
||||
print("-------------------------------------------------------------")
|
||||
|
||||
for index, batch in enumerate(tqdm(test_dataloader)):
|
||||
'''
|
||||
for s2s
|
||||
'''
|
||||
input_shape = (batch['src_input_ids'].shape[0], args.tgt_max_len, args.in_channel)
|
||||
src_input_ids = batch['src_input_ids']
|
||||
tgt_input_ids = batch['tgt_input_ids']
|
||||
# print(p_input_ids.shape)
|
||||
src_attention_mask = batch['src_attention_mask']
|
||||
model_kwargs = {'src_input_ids' : src_input_ids, 'src_attention_mask': src_attention_mask}
|
||||
|
||||
sample = sample_fn(
|
||||
model,
|
||||
input_shape,
|
||||
clip_denoised=False,
|
||||
denoised_fn=partial(denoised_fn_round, args, emb_model.cuda()),
|
||||
model_kwargs=model_kwargs,
|
||||
top_p=-1.0,
|
||||
interval_step=args.interval_step,
|
||||
)
|
||||
|
||||
print("sample result shape: ", sample.shape)
|
||||
print('decoding for e2e... ')
|
||||
|
||||
logits = model.module.get_logits(sample)
|
||||
cands = torch.topk(logits, k=1, dim=-1)
|
||||
sample_id_list = cands.indices
|
||||
#print("decode id list example :", type(sample_id_list[0]), " ", sample_id_list[0])
|
||||
|
||||
'''
|
||||
for s2s
|
||||
'''
|
||||
# print("src text: ", tokenizer.decode(src_input_ids.squeeze()))
|
||||
# print("tgt text: ", tokenizer.decode(tgt_input_ids.squeeze()))
|
||||
|
||||
print("sample control generate query: ")
|
||||
for sample_id in sample_id_list:
|
||||
sentence = tokenizer.decode(sample_id.squeeze())
|
||||
each_sample_list.append(clean(sentence))
|
||||
# print(sentence)
|
||||
|
||||
# total_sample_list.append(each_sample_list)
|
||||
out_path = os.path.join(args.generate_path, "rank" + str(dist.get_rank()) + "_gen_seed_101" +
|
||||
"_num" + str(args.num_samples) + "_epoch" + str(epoch + 1 + epoch_num) + ".txt")
|
||||
with open(out_path, 'w') as f:
|
||||
for sentence in each_sample_list:
|
||||
f.write(sentence + '\n')
|
||||
|
||||
else:
|
||||
        raise NotImplementedError
|
||||
|
||||
def clean(sentence):
|
||||
sentence = sentence.replace('[CLS]', '')
|
||||
sentence = sentence.replace('[SEP]', '')
|
||||
sentence = sentence.replace('[PAD]', '')
|
||||
sentence = sentence.replace('[UNK]', 'unk')
|
||||
return sentence.strip()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()

@@ -0,0 +1,220 @@
import argparse
|
||||
import os
|
||||
from transformers import set_seed
|
||||
from diffusion_util.resample import create_named_schedule_sampler
|
||||
from transformers import AutoTokenizer
|
||||
import json
|
||||
from util import logger
|
||||
from train_util import dist_util
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from util.util import (
|
||||
create_model_and_diffusion,
|
||||
args_to_dict,
|
||||
)
|
||||
from data_util.pretrain_data_util import load_pretrain_data
|
||||
from torch.serialization import default_restore_location
|
||||
from train_util.pretrain_util import PretrainLoop
|
||||
import collections
|
||||
|
||||
import os
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
CheckpointState = collections.namedtuple("CheckpointState",
|
||||
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
|
||||
|
||||
|
||||
def get_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# out path
|
||||
parser.add_argument('--checkpoint_path', type=str, default='', help='output path')
|
||||
parser.add_argument('--pretrain_model_path', type=str, default=None, help='continue train')
|
||||
|
||||
|
||||
# load model
|
||||
parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
|
||||
parser.add_argument('--model_channels', type=int, default=768, help='Try to set it to the same size as the model hidden')
|
||||
    parser.add_argument('--in_channel', type=int, default=768, help='The input channel size here must be the same as the word embedding size')
|
||||
    parser.add_argument('--out_channel', type=int, default=768, help='The output dimension size; recommended to match the word embedding size for easier inference')
|
||||
parser.add_argument('--dropout', type=float, default=0.1, help='')
|
||||
parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
|
||||
parser.add_argument('--logits_mode', type=int, default=1, help='final logits mode of Diffusion model')
|
||||
parser.add_argument('--vocab_size', type=int, default=30522, help='vocab size')
|
||||
parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='')
|
||||
parser.add_argument('--token_emb_type', type=str, default='random', help='token embedding type')
|
||||
parser.add_argument("--init_pretrained", default=False, action="store_true", help="Whether to using pretrain BERT encoder")
|
||||
parser.add_argument("--fix_encoder", default=False, action="store_true",
|
||||
help="Whether to training encoder")
|
||||
|
||||
|
||||
# load diffusion
|
||||
parser.add_argument('--diffusion_steps', type=int, default=2000, help='Diffusion model maximum T')
|
||||
    parser.add_argument('--use_kl', default=False, action="store_true", help="Whether to use the KL loss in the diffusion loss")
|
||||
parser.add_argument('--training_mode', type=str, default='e2e', help='using e2e simple loss or e2e loss or s2s loss')
|
||||
parser.add_argument('--noise_schedule', type=str, default='sqrt', help='How to plan the noise change of Gaussian distribution')
|
||||
parser.add_argument('--predict_xstart', default=False, action="store_true", help="Model prediction target, if True, predict xstart, if False, predict EPSILON")
|
||||
parser.add_argument("--sigma_small", default=False, action="store_true", help="about learning variance")
|
||||
parser.add_argument("--rescale_learned_sigmas", default=True, action="store_false", help="about learning variance")
|
||||
parser.add_argument("--rescale_timesteps", default=True, action="store_false", help="about time rescale")
|
||||
|
||||
# sample t
|
||||
parser.add_argument('--schedule_sampler', type=str, default='uniform', help='how to sample t per batch, uniform is Uniform sampling, loss-second-moment is Sampling according to loss')
|
||||
|
||||
# data args
|
||||
parser.add_argument('--data_path', type=str, default='',help='data path')
|
||||
parser.add_argument('--data_name', type=str, default='', help='data name')
|
||||
    # for pretrain
    parser.add_argument('--pre_max_len', type=int, default=512, help='pretrain max len')
    parser.add_argument('--mask_pro', type=float, default=0.3, help='mask probability')
|
||||
|
||||
|
||||
# training args
|
||||
parser.add_argument('--train_type', type=str, default='LM_Diffusion', help='LM_Diffusion or S2S_Diffusion')
|
||||
parser.add_argument('--lr_anneal_steps', type=int, default=200000, help='total step')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='')
|
||||
parser.add_argument('--lr', type=float, default=1e-04, help='')
|
||||
parser.add_argument('--warmup_steps', type=int, default=20000, help='')
|
||||
parser.add_argument('--ema_rate', type=str, default='0.9999', help='ema training to stable model')
|
||||
parser.add_argument('--resume_checkpoint', type=str, default=None, help='')
|
||||
parser.add_argument('--eval_interval', type=int, default=2000, help='')
|
||||
parser.add_argument('--log_interval', type=int, default=100, help='')
|
||||
parser.add_argument('--save_interval', type=int, default=50000, help='')
|
||||
parser.add_argument('--gradient_accumulation_steps', type=int, default=8, help='')
|
||||
|
||||
    parser.add_argument('--weight_decay', type=float, default=0.0, help='')
|
||||
parser.add_argument('--gradient_clipping', type=float, default=-1., help='')
|
||||
parser.add_argument("--use_fp16", default=False, action="store_true", help="about learning variance")
|
||||
parser.add_argument('--fp16_scale_growth', type=float, default=1e-3, help='')
|
||||
|
||||
# seed
|
||||
parser.add_argument('--seed', type=int, default=101, help='')
|
||||
|
||||
    # multi-gpu
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def setup_env(args):
|
||||
if args.local_rank == -1:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
device = torch.device("cuda", args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# store args
|
||||
if args.local_rank != -1:
|
||||
args.world_size = torch.distributed.get_world_size()
|
||||
args.rank = dist.get_rank()
|
||||
|
||||
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
|
||||
logger.info('Reading saved model from %s', model_file)
|
||||
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
|
||||
logger.info('model_state_dict keys %s', state_dict.keys())
|
||||
return CheckpointState(**state_dict)
|
||||
|
||||
def main():
|
||||
# args setting
|
||||
args = get_arguments()
|
||||
|
||||
# out dir set
|
||||
if dist.get_rank() == 0:
|
||||
if not os.path.exists(args.checkpoint_path):
|
||||
os.makedirs(args.checkpoint_path)
|
||||
# dist.barrier()
|
||||
|
||||
logger.log(f'saving the hyperparameters to {args.checkpoint_path}/training_args.json')
|
||||
with open(f'{args.checkpoint_path}/training_args.json', 'w') as f:
|
||||
json.dump(args.__dict__, f, indent=2)
|
||||
|
||||
# seed setting
|
||||
# set_seed(args.seed)
|
||||
    # ddp setting
|
||||
setup_env(args)
|
||||
|
||||
# logger setting
|
||||
log_path = os.path.join(args.checkpoint_path, 'log.txt')
|
||||
logger.configure(dir=log_path)
|
||||
|
||||
model, diffusion = create_model_and_diffusion(
|
||||
args
|
||||
)
|
||||
# load pretrain model to continue train
|
||||
if args.pretrain_model_path is not None:
|
||||
print("load model ckpt at :", args.pretrain_model_path)
|
||||
saved_state = load_states_from_checkpoint(args.pretrain_model_path)
|
||||
model.load_state_dict(saved_state.model_dict, strict=False)
|
||||
model.to(args.device)
|
||||
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters())
|
||||
logger.log(f'the parameter count is {pytorch_total_params}')
|
||||
|
||||
'''
|
||||
time step schedule sampler
|
||||
'''
|
||||
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
|
||||
|
||||
logger.log("load tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
||||
|
||||
'''
|
||||
for s2s
|
||||
'''
|
||||
dataname_list = ['book1','book2','book3','book4','book5',
|
||||
'wiki1', 'wiki2', 'wiki3', 'wiki4', 'wiki5',
|
||||
'stories1','stories2','stories3','stories4','stories5',
|
||||
'openweb1','openweb2','openweb3','openweb4','openweb5',
|
||||
'realnews1', 'realnews2', 'realnews3', 'realnews4', 'realnews5',
|
||||
'realnews6', 'realnews7', 'realnews8', 'realnews9','realnews10']
|
||||
|
||||
# roll data list
|
||||
start_index = 0
|
||||
|
||||
'''
|
||||
training
|
||||
'''
|
||||
logger.log("pretraining Diffusion LM model using ")
|
||||
PretrainLoop(
|
||||
# training type
|
||||
train_type=args.train_type,
|
||||
# Training Core
|
||||
model=model,
|
||||
diffusion=diffusion,
|
||||
data=dataname_list,
|
||||
eval_data=None,
|
||||
schedule_sampler=schedule_sampler,
|
||||
checkpoint_path=args.checkpoint_path,
|
||||
# Training Parameters
|
||||
batch_size=args.batch_size,
|
||||
lr=args.lr,
|
||||
ema_rate=args.ema_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
lr_anneal_steps=args.lr_anneal_steps,
|
||||
gradient_clipping=args.gradient_clipping,
|
||||
# fp16
|
||||
use_fp16=args.use_fp16,
|
||||
fp16_scale_growth=args.fp16_scale_growth,
|
||||
# Training Log
|
||||
resume_checkpoint=args.resume_checkpoint,
|
||||
eval_interval=args.eval_interval,
|
||||
log_interval=args.log_interval,
|
||||
save_interval=args.save_interval,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
# device
|
||||
device=args.device,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
).run_loop()
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()

@@ -0,0 +1,160 @@
# GENIE

This repo provides the code and models for [Text Generation with Diffusion Language Models: A Pre-training Approach with Continuous Paragraph Denoise](https://arxiv.org/abs/2212.11685).

## 🚀 Overview

In this paper, we introduce a novel d**I**ffusion language mod**E**l pre-training framework for text **GEN**eration, which we call **GENIE**. GENIE is a large-scale pre-trained diffusion language model that consists of an encoder and a diffusion-based decoder, and it generates text by gradually transforming a random noise sequence into a coherent text sequence.

<div align=center><img src="image\GENIE.png" width = "600" height = 300/></div>

To pre-train GENIE on a large-scale language corpus, we design a novel pre-training method called *continuous paragraph denoise* (CPD), which encourages the diffusion decoder to reconstruct a clean text paragraph from a corrupted version while preserving semantic and syntactic coherence.

You can find more details in the [paper](https://arxiv.org/abs/2212.11685).
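
As a rough illustration of the idea (a minimal sketch based on the `Pre_dataset_type2` corruption in `data_util/pretrain_data_util.py` from this commit, not a formal statement of the paper's objective), CPD picks one contiguous span of a paragraph as the denoising target and collapses it to a single mask token on the encoder side:

```python
import random

def cpd_corrupt(token_ids, mask_id, mask_pro=0.3):
    """Hypothetical helper mirroring Pre_dataset_type2.__getitem__: one contiguous
    span (about mask_pro of the paragraph) becomes the target that the diffusion
    decoder must reconstruct, while the encoder sees the rest of the paragraph
    with that span collapsed into a single mask token."""
    span_len = int(len(token_ids) * mask_pro)
    start = random.randint(0, len(token_ids) - span_len - 1)
    target = token_ids[start:start + span_len]
    source = token_ids[:start] + [mask_id] + token_ids[start + span_len:]
    return source, target
```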

## ⚙️ Experiment Preparation

**Dependencies:**

- python>=3.6
- torch>=1.7.1
- datasets>=1.12.1
- transformers>=4.9.2 (Huggingface)
- pyrouge==0.1.3

**Downstream Task Dataset:**

The text generation benchmarks we use are well known and widely used, including *XSum*, *CNN/DailyMail*, and *GigaWord*. You can find more detailed information about the datasets and how to obtain them [here](https://microsoft.github.io/glge/).

**Model:**

We have released the checkpoint of GENIE after pre-training on a `160G` corpus (6-layer encoder and 6-layer decoder):

- **GENIE V1** [[link](https://drive.google.com/file/d/1-AZssEmgs0QdTp_w8-_4cPi0cV-Hot4N/view?usp=share_link)]

You can also quickly get the GENIE checkpoints fine-tuned on *XSum*, *CNN/DailyMail*, and *GigaWord* here:

- GENIE XSum [[link](https://drive.google.com/file/d/1-3NJwuDbSV00TwYs5FqG5cHvCY10CW0h/view?usp=share_link)]
- GENIE CNN/DailyMail [[link](https://drive.google.com/file/d/1-6shROw2TLWPTMLQbESmhQzfI0Z3pAOm/view?usp=share_link)]
- GENIE GigaWord [[link](https://drive.google.com/file/d/1-7PoPTX0w4Q_Sh4qrxB1WQId1tBCydY-/view?usp=share_link)]

We will continue to update and optimize this repo in the future.

## 💡 Pre-training

For pre-training, we use data consisting of `160G` of news, books, stories, and web text. We trained **GENIE V1** on `8 * 40G` `A100` GPUs for `50 days`. If you are interested in our pre-training process, please refer to `Genie_Pretrain.py`. Here we provide the pre-training running script for reference:

```shell
OUT_DIR="/Your/output/path"
DATA_PATH="/Your/pretrain/data/path"

python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9489 \
./GENIE_main/Genie_Pretrain.py \
--checkpoint_path=$OUT_DIR \
--model_channels 128 --in_channel 128 --out_channel 128 --vocab_size 30522 \
--config_name="bert-base-uncased" --token_emb_type="random" --model_arch="s2s_CAT" \
--diffusion_steps 2000 --predict_xstart --noise_schedule="sqrt" --training_mode="s2s" \
--schedule_sampler="uniform" --pre_max_len 512 --mask_pro 0.3 --seed 2023 \
--data_path=$DATA_PATH \
--batch_size 64 --lr 1e-04 --warmup_steps 300000 --train_type="S2S_Diffusion" \
--eval_interval 2000 --log_interval 2000 --save_interval 20000
```

Pre-training a diffusion model needs careful hyperparameter adjustment and a reasonable training configuration, especially the vocabulary size and the input/output channel dimensions, which are worth continued exploration.

## ⚽ Fine-tuning

In this section, we use the *XSum* dataset as an example to demonstrate how GENIE is fine-tuned on a downstream task. The running script for fine-tuning is as follows:

```shell
OUT_DIR="/Your/output/path"
DATA_PATH="/Your/data/path"
DATA_NAME="xsum_data"
PRETRAIN_CKPT_PATH="/Your/pretrain_ckpt/path"

python -u -m torch.distributed.launch --nproc_per_node=4 --master_port=9421 \
./GENIE_main/Genie_Finetune.py \
--checkpoint_path=$OUT_DIR \
--model_channels 128 --in_channel 128 --out_channel 128 --vocab_size 30522 \
--config_name="bert-base-uncased" --token_emb_type="random" --model_arch="s2s_CAT" \
--diffusion_steps 2000 --predict_xstart --noise_schedule="sqrt" --training_mode="s2s" \
--schedule_sampler="loss-second-moment" --tgt_max_len 64 --src_max_len 512 --data_name=$DATA_NAME \
--data_path=$DATA_PATH \
--lr_anneal_steps 120000 --batch_size 64 --lr 5e-05 --warmup_steps 7200 --train_type="S2S_Diffusion" \
--eval_interval 200 --log_interval 200 --save_interval 20000 \
--pretrain_model_path=$PRETRAIN_CKPT_PATH
```

Important parameter settings:

- `--checkpoint_path`: Location of the model checkpoints and log files written during fine-tuning.
- `--data_path`: Root directory of the downstream task datasets.
- `--data_name`: Name of the downstream task dataset. Make sure your data lives in the directory formed by `data_path + data_name`, which must contain the files `train.src`, `train.tgt`, `dev.src`, `dev.tgt`, `test.src`, and `test.tgt` (see the layout sketch after this list).
- `--pretrain_model_path`: GENIE checkpoint path after pre-training.
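
For example (hypothetical paths, following the *XSum* setup above), the data loaders in `data_util/s2s_data_util.py` read the split files from an `org_data` subfolder, so the expected layout is:

```text
/Your/data/path/
└── xsum_data/
    └── org_data/
        ├── train.src
        ├── train.tgt
        ├── dev.src
        ├── dev.tgt
        ├── test.src
        └── test.tgt
```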

If you want to switch to another fine-tuning task, you only need to organize the data into the same standard format (e.g. *CNN/DailyMail* or *GigaWord*) and change `DATA_NAME` to `cnndm_data` or `gigaword_data`.

If you need to train from scratch (w/o pre-training), just remove the parameter `--pretrain_model_path`.

## 💬 Generate

In this section, we show how to batch-generate text from a trained GENIE model. We sample Gaussian noise and iteratively denoise it with GENIE to recover the text. The running script for generation is as follows:

```shell
OUT_DIR="/Your/output/path"
MODEL_DIR="/Your/model/ckpt/path"
DATA_PATH="/Your/data/path"
DATA_NAME="xsum_data"

python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9498 \
./GENIE_main/Genie_Generate.py \
--generate_path=$OUT_DIR \
--eval_model_path=$MODEL_DIR \
--data_path=$DATA_PATH \
--model_channels 128 --in_channel 128 --out_channel 128 --vocab_size 30522 \
--config_name="bert-base-uncased" --token_emb_type="random" \
--diffusion_steps 2000 --predict_xstart --noise_schedule="sqrt" \
--num_samples 5 --model_arch="s2s_CAT" --data_name=$DATA_NAME \
--training_mode="s2s" --tgt_max_len 64 --src_max_len 512 --batch_size=200 \
--interval_step 1 --seed 2023
```

Important parameter settings:

- `--generate_path`: Output location of the generated text.
- `--eval_model_path`: Path of the trained model checkpoint used for generation.
- `--data_path`: Root directory of the downstream task datasets.
- `--data_name`: Name of the downstream task dataset. Make sure your data lives in the directory formed by `data_path + data_name`, which must contain the files `train.src`, `train.tgt`, `dev.src`, `dev.tgt`, `test.src`, and `test.tgt`.
- `--num_samples`: Number of Gaussian noise samples drawn per input, i.e. the number of texts generated per sample.
- `--interval_step`: Interval step for denoising; defaults to 1.

You can adjust `--batch_size` and the number of parallel GPUs (`--nproc_per_node`) according to the capacity of the devices you are using. The name of each resulting text file is formatted as `rank[gpu_id]_gen_seed_[seed]_num[num_samples]_epoch[sample epoch].txt`.
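
For instance, with `--nproc_per_node=8` and `--num_samples 5` as above, rank 0's first sampling epoch would be written to a file like the one below (in this commit's `Genie_Generate.py` the seed field of the file name is hard-coded to `101`):

```text
rank0_gen_seed_101_num5_epoch1.txt
```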

Finally, we need to merge the generated text by running the following script:

```shell
OUT_DIR="/Your/output/path"
DATA_PATH="/Your/data/path"
DATA_NAME="xsum_data"

python ./GENIE_main/integration/eval_split.py \
--generate_path=$OUT_DIR \
--data_path=$DATA_PATH \
--num_samples 5 --data_name=$DATA_NAME --n_gpu 8 --seed 2023
```

Note that the parameter settings above need to be consistent with those used during generation; the best results will be saved under `--generate_path`. If you want to reproduce the results in the paper, please use the official GLGE evaluation method [here](https://github.com/microsoft/ProphetNet/tree/master/GLGE_baselines).

## 📜 Citation

Please cite our paper if you use [GENIE](https://arxiv.org/abs/2212.11685) in your work:

```bibtex
@article{lin2022genie,
  title   = {Text Generation with Diffusion Language Models: A Pre-training Approach with Continuous Paragraph Denoise},
  author  = {Zhenghao Lin and Yeyun Gong and Yelong Shen and Tong Wu and Zhihao Fan and Chen Lin and Nan Duan and Weizhu Chen},
  journal = {arXiv preprint arXiv:2212.11685},
  year    = {2022}
}
```

@@ -0,0 +1,224 @@
from tqdm import tqdm
|
||||
import os
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data.dataset import Dataset
|
||||
import torch
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
def load_loop_pretrain_data(args, padding_mode, tokenizer, data_name = None):
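    # Loads one pre-tokenized shard ({data_name}.npy) from data_path, drops sequences with
    # fewer than 30 non-pad tokens, and wraps the rest in a span-masking pretrain dataset.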
|
||||
print("***** load " + data_name + " train src dataset*****")
|
||||
|
||||
path = os.path.join(args.data_path, data_name + '.npy')
|
||||
input_id_list = np.load(path, allow_pickle=True)
|
||||
|
||||
# filter
|
||||
input_id_list = np.array([input_id for input_id in input_id_list if np.count_nonzero(input_id) >= 30])
|
||||
|
||||
if padding_mode == 'max_len':
|
||||
dataset = Pre_dataset(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
|
||||
elif padding_mode == 'conti_tgt':
|
||||
print("using new pretrain method...")
|
||||
dataset = Pre_dataset_type2(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
|
||||
elif padding_mode == 'block':
|
||||
print("padding block is under realization")
|
||||
pass
|
||||
else:
|
||||
        raise NotImplementedError
|
||||
|
||||
print("example of src id lists: ", dataset[50][0])
|
||||
print("example of tgt id lists: ", dataset[50][1])
|
||||
print("total query dataset len :", len(dataset))
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
|
||||
def load_pretrain_data(args, padding_mode, tokenizer, data_name = None):
|
||||
questions = []
|
||||
print("***** load " + data_name + " train src dataset*****")
|
||||
|
||||
input_id_list = None
|
||||
|
||||
if data_name == "book" or data_name == "openweb" or data_name == "wiki" or data_name == "stories":
|
||||
for i in range(5):
|
||||
path = os.path.join(args.data_path, args.data_name + str(i+1) + '.npy')
|
||||
input_id_list_pre = np.load(path, allow_pickle=True)
|
||||
if i == 0:
|
||||
input_id_list = input_id_list_pre
|
||||
else:
|
||||
input_id_list = np.concatenate((input_id_list, input_id_list_pre), axis=0)
|
||||
# with open(path, "r", encoding="utf-8") as ifile:
|
||||
# for line in tqdm(ifile):
|
||||
# line = line.strip()
|
||||
# text = line
|
||||
# tgt.append(text)
|
||||
|
||||
elif data_name == 'realnews':
|
||||
# for i in range(10):
|
||||
# path = os.path.join(args.data_path, args.data_name + str(i+1) + '.txt')
|
||||
# with open(path, "r", encoding="utf-8") as ifile:
|
||||
# for line in tqdm(ifile):
|
||||
# line = line.strip()
|
||||
# text = line
|
||||
# tgt.append(text)
|
||||
for i in range(10):
|
||||
path = os.path.join(args.data_path, args.data_name + str(i+1) + '.npy')
|
||||
input_id_list_pre = np.load(path, allow_pickle=True)
|
||||
if i == 0:
|
||||
input_id_list = input_id_list_pre
|
||||
else:
|
||||
input_id_list = np.concatenate((input_id_list, input_id_list_pre), axis=0)
|
||||
|
||||
else:
|
||||
        raise NotImplementedError
|
||||
|
||||
# filter
|
||||
input_id_list = np.array([input_id for input_id in input_id_list if np.count_nonzero(input_id) >= 256])
|
||||
|
||||
|
||||
# print("example of src text: ", src[50])
|
||||
print("example of input id: ", input_id_list[50])
|
||||
|
||||
if padding_mode == 'max_len':
|
||||
dataset = Pre_dataset(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
|
||||
elif padding_mode == 'conti_tgt':
|
||||
print("using new pretrain method...")
|
||||
dataset = Pre_dataset_type2(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
|
||||
elif padding_mode == 'block':
|
||||
print("padding block is under realization")
|
||||
pass
|
||||
else:
|
||||
        raise NotImplementedError
|
||||
|
||||
print("example of src id lists: ", dataset[50][0])
|
||||
print("example of tgt id lists: ", dataset[50][1])
|
||||
print("total query dataset len :", len(dataset))
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
|
||||
class Pre_dataset(Dataset):
|
||||
def __init__(self, tgt_id, tokenizer, mask_pro=0.3, maxlength=512, span_size=8, mask_mode='random'):
|
||||
self.tgt_id = tgt_id
|
||||
self.tokenizer = tokenizer
|
||||
self.maxlength = maxlength
|
||||
self.mask_pro = mask_pro
|
||||
self.span_size = span_size
|
||||
self.mask_token_index = self.tokenizer.mask_token_id
|
||||
self.pad_token_index = self.tokenizer.pad_token_id
|
||||
self.all_special_token = self.tokenizer.all_special_ids
|
||||
|
||||
|
||||
def __getitem__(self, index):
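        # Span corruption: sample ~(len * mask_pro / span_size) span start positions, replace
        # each start token with [MASK], drop the remaining tokens of every span, and pad the
        # shortened sequence back to maxlength; the uncorrupted sequence is the target.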
|
||||
tgt_example = self.tgt_id[index]
|
||||
# src_input_ids = tgt_example.tolist()
|
||||
tgt_input_ids = (torch.from_numpy(tgt_example)).long()
|
||||
src_input_ids = tgt_input_ids.clone()
|
||||
id_len = torch.nonzero(src_input_ids).shape[0]
|
||||
mask_span_num = int((id_len * self.mask_pro) // self.span_size) + 1
|
||||
# print("mask_span_num:", mask_span_num)
|
||||
mask_index = torch.randint(0, id_len, (mask_span_num,))
|
||||
# print("mask_index:", mask_index)
|
||||
mask_id_mask = torch.full(src_input_ids.shape, False, dtype=torch.bool)
|
||||
retain_id_mask = torch.full(src_input_ids.shape, True, dtype=torch.bool)
|
||||
mask_id_mask[mask_index] = True
|
||||
|
||||
del_index = mask_index.tolist()
|
||||
for i in mask_index:
|
||||
del_index.extend(list(range(i + 1, i + self.span_size)))
|
||||
del_index = [i for i in del_index if i < id_len]
|
||||
del_index = torch.from_numpy(np.array(list(set(del_index))))
|
||||
# print("del_index", del_index)
|
||||
retain_id_mask[del_index] = False
|
||||
retain_id_mask = retain_id_mask | mask_id_mask
|
||||
src_input_ids[mask_id_mask] = self.mask_token_index
|
||||
src_input_ids = src_input_ids[retain_id_mask].tolist()
|
||||
# print("src_input_ids1:", len(src_input_ids))
|
||||
src_input_ids = src_input_ids + [self.pad_token_index] * (self.maxlength - len(src_input_ids))
|
||||
# print("src_input_ids2:", len(src_input_ids))
|
||||
src_input_ids = torch.from_numpy(np.array(src_input_ids)).long()
|
||||
|
||||
return src_input_ids.unsqueeze(0), tgt_input_ids.unsqueeze(0)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tgt_id)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
src_tensor = torch.cat([feature[0] for feature in features])
|
||||
tgt_tensor = torch.cat([feature[1] for feature in features])
|
||||
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
|
||||
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
|
||||
|
||||
return fn
|
||||
|
||||
class Pre_dataset_type2(Dataset):
|
||||
def __init__(self, tgt_id, tokenizer, mask_pro=0.3, maxlength=512, mask_mode='random'):
|
||||
self.tgt_id = tgt_id
|
||||
self.tokenizer = tokenizer
|
||||
self.maxlength = maxlength
|
||||
self.mask_pro = mask_pro
|
||||
self.tgtmaxlength = int(maxlength * mask_pro) + 1
|
||||
self.mask_token_index = self.tokenizer.mask_token_id
|
||||
self.pad_token_index = self.tokenizer.pad_token_id
|
||||
self.all_special_token = self.tokenizer.all_special_ids
|
||||
|
||||
|
||||
def __getitem__(self, index):
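        # Continuous-span corruption: pick one contiguous span of ~(len * mask_pro) tokens as
        # the target, replace it with a single [MASK] token in the source, and pad source and
        # target back to their fixed lengths.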
|
||||
tgt_example = self.tgt_id[index]
|
||||
# src_input_ids = tgt_example.tolist()
|
||||
tgt_input_ids = (torch.from_numpy(tgt_example)).long()
|
||||
src_input_ids = tgt_input_ids.clone()
|
||||
id_len = torch.nonzero(src_input_ids).shape[0]
|
||||
|
||||
# mask_span_num = int((id_len * self.mask_pro) // self.span_size) + 1
|
||||
mask_span_len = int(id_len * self.mask_pro)
|
||||
# print("mask_span_num:", mask_span_num)
|
||||
mask_index = random.randint(0, id_len-mask_span_len-1)
|
||||
|
||||
tgt_input_ids = src_input_ids.tolist()[mask_index:mask_index+mask_span_len]
|
||||
|
||||
src_input_ids[mask_index] = self.mask_token_index
|
||||
|
||||
# print("mask_index:", mask_index)
|
||||
# mask_span_len
|
||||
# mask_id_mask = torch.full(src_input_ids.shape, False, dtype=torch.bool)
|
||||
retain_id_mask = torch.full(src_input_ids.shape, True, dtype=torch.bool)
|
||||
# mask_id_mask[mask_index] = True
|
||||
|
||||
# del_index = mask_index.tolist()
|
||||
del_index = list(range(mask_index + 1, mask_index + mask_span_len))
|
||||
del_index = torch.from_numpy(np.array(del_index))
|
||||
retain_id_mask[del_index] = False
|
||||
# src_input_ids[mask_id_mask] = self.mask_token_index
|
||||
src_input_ids = src_input_ids[retain_id_mask].tolist()
|
||||
# print("src_input_ids1:", len(src_input_ids))
|
||||
src_input_ids = src_input_ids + [self.pad_token_index] * (self.maxlength - len(src_input_ids))
|
||||
# print("src_input_ids2:", len(src_input_ids))
|
||||
tgt_input_ids = tgt_input_ids + [self.pad_token_index] * (self.tgtmaxlength - len(tgt_input_ids))
|
||||
|
||||
src_input_ids = torch.from_numpy(np.array(src_input_ids)).long()
|
||||
tgt_input_ids = torch.from_numpy(np.array(tgt_input_ids)).long()
|
||||
|
||||
return src_input_ids.unsqueeze(0), tgt_input_ids.unsqueeze(0)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tgt_id)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
src_tensor = torch.cat([feature[0] for feature in features])
|
||||
tgt_tensor = torch.cat([feature[1] for feature in features])
|
||||
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
|
||||
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
|
||||
|
||||
return fn
|
||||
|
||||
if __name__ == "__main__":
|
||||
pretrain_max_len = 512

@@ -0,0 +1,408 @@
from tqdm import tqdm
|
||||
import os
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data.dataset import Dataset
|
||||
import torch
|
||||
import jsonlines
|
||||
|
||||
|
||||
|
||||
def load_s2s_data(args, padding_mode, split, tokenizer):
|
||||
questions = []
|
||||
if split == 'train':
|
||||
print("***** load " + args.data_name + " train src dataset*****")
|
||||
src = []
|
||||
train_src_path = os.path.join(args.data_path, args.data_name + "/org_data/train.src")
|
||||
with open(train_src_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
src.append(text)
|
||||
|
||||
print("***** load " + args.data_name + " train tgt dataset*****")
|
||||
tgt = []
|
||||
train_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/train.tgt")
|
||||
with open(train_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
tgt.append(text)
|
||||
|
||||
elif split == 'dev':
|
||||
|
||||
print("***** load " + args.data_name + " dev src dataset*****")
|
||||
src = []
|
||||
dev_src_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.src")
|
||||
with open(dev_src_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
src.append(text)
|
||||
|
||||
print("***** load " + args.data_name + " dev tgt dataset*****")
|
||||
tgt = []
|
||||
dev_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.tgt")
|
||||
with open(dev_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
tgt.append(text)
|
||||
|
||||
elif split == 'test':
|
||||
|
||||
print("***** load " + args.data_name + " test src dataset*****")
|
||||
src = []
|
||||
test_src_path = os.path.join(args.data_path, args.data_name + "/org_data/test.src")
|
||||
with open(test_src_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
src.append(text)
|
||||
|
||||
print("***** load " + args.data_name + " dev tgt dataset*****")
|
||||
tgt = []
|
||||
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
|
||||
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
tgt.append(text)
|
||||
|
||||
else:
|
||||
print("no such split of data...")
|
||||
exit(1)  # error path, so exit with a non-zero status
|
||||
|
||||
print("example of src text: ", src[50])
|
||||
print("example of tgt text: ", tgt[50])
|
||||
|
||||
if padding_mode == 'max_len':
|
||||
if args.data_name == "squadqg_data":
|
||||
dataset = QG_dataset_Diff(src, tgt, tokenizer, src_maxlength=args.src_max_len,
|
||||
answer_maxlength=args.answer_max_len, tgt_maxlength=args.tgt_max_len)
|
||||
else:
|
||||
dataset = S2S_dataset(src, tgt, tokenizer, src_maxlength=args.src_max_len, tgt_maxlength=args.tgt_max_len)
|
||||
elif padding_mode == 'block':
print("padding_mode 'block' is not implemented yet")
raise NotImplementedError
else:
raise NotImplementedError
|
||||
|
||||
print("example of src id lists: ", dataset[50][0])
|
||||
print("example of tgt id lists: ", dataset[50][1])
|
||||
print("total query dataset len :", len(dataset))
|
||||
|
||||
return dataset
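# --- Illustrative sketch (not part of the original file): the argument fields that
# load_s2s_data reads. The concrete paths and sizes below are hypothetical; the data
# layout it expects is <data_path>/<data_name>/org_data/{train,dev,test}.{src,tgt}.
def _example_load_s2s(tokenizer):
    from argparse import Namespace
    args = Namespace(data_path="data", data_name="xsum_data",
                     src_max_len=512, tgt_max_len=64,
                     answer_max_len=20)  # answer_max_len is only used for "squadqg_data"
    return load_s2s_data(args, padding_mode='max_len', split='train', tokenizer=tokenizer)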
|
||||
|
||||
|
||||
'''
|
||||
for AR seq2seq training
|
||||
'''
|
||||
def load_s2s_data_AR(args, padding_mode, split, tokenizer):
|
||||
if split == 'train':
|
||||
print("***** load " + args.data_name + " train src dataset*****")
|
||||
src = []
|
||||
train_src_path = os.path.join(args.data_path, args.data_name + "/org_data/train.src")
|
||||
with open(train_src_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
src.append(text)
|
||||
|
||||
print("***** load " + args.data_name + " train tgt dataset*****")
|
||||
tgt = []
|
||||
train_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/train.tgt")
|
||||
with open(train_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
tgt.append(text)
|
||||
|
||||
elif split == 'dev':
|
||||
|
||||
print("***** load " + args.data_name + " dev src dataset*****")
|
||||
src = []
|
||||
dev_src_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.src")
|
||||
with open(dev_src_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
src.append(text)
|
||||
|
||||
print("***** load " + args.data_name + " dev tgt dataset*****")
|
||||
tgt = []
|
||||
dev_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.tgt")
|
||||
with open(dev_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
tgt.append(text)
|
||||
# src = src[:100]
|
||||
# tgt = tgt[:100]
|
||||
|
||||
elif split == 'test':
|
||||
|
||||
print("***** load " + args.data_name + " test src dataset*****")
|
||||
src = []
|
||||
test_src_path = os.path.join(args.data_path, args.data_name + "/org_data/test.src")
|
||||
with open(test_src_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
src.append(text)
|
||||
|
||||
print("***** load " + args.data_name + " dev tgt dataset*****")
|
||||
tgt = []
|
||||
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
|
||||
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
tgt.append(text)
|
||||
# src = src[:10]
|
||||
# tgt = tgt[:10]
|
||||
|
||||
else:
|
||||
print("no such split of data...")
|
||||
exit(1)
|
||||
|
||||
print("example of src text: ", src[9].replace("<S_SEP>",'\n'))
|
||||
print("example of tgt text: ", tgt[9].replace("<S_SEP>",'\n'))
|
||||
|
||||
if padding_mode == 'max_len':
|
||||
dataset = S2S_AR_dataset(src, tgt, tokenizer, src_maxlength=args.src_max_len, tgt_maxlength=args.tgt_max_len)
|
||||
elif padding_mode == 'block':
print("padding_mode 'block' is not implemented yet")
raise NotImplementedError
else:
raise NotImplementedError
|
||||
|
||||
print("example of src id lists: ", dataset[9][0])
|
||||
print("example of tgt id lists: ", dataset[9][1])
|
||||
print("total query dataset len :", len(dataset))
|
||||
|
||||
return dataset
|
||||
|
||||
'''
|
||||
load baseline data on SeqDiffusion
|
||||
'''
|
||||
def load_s2s_jsonl_data(args, padding_mode, split, tokenizer):
|
||||
questions = []
|
||||
if split == 'train':
|
||||
print("***** load " + args.data_name + " train src and tgt dataset*****")
|
||||
src = []
|
||||
tgt = []
|
||||
|
||||
train_path = os.path.join(args.data_path, args.data_name + "/train.jsonl")
|
||||
with jsonlines.open(train_path) as reader:
|
||||
for obj in reader:
|
||||
tgt.append(obj['trg'])
|
||||
src.append(obj['src'])
|
||||
|
||||
elif split == 'dev':
|
||||
|
||||
print("***** load " + args.data_name + " dev src and tgt dataset*****")
|
||||
src = []
|
||||
tgt = []
|
||||
dev_path = os.path.join(args.data_path, args.data_name + "/valid.jsonl")
|
||||
with jsonlines.open(dev_path) as reader:
|
||||
for obj in reader:
|
||||
tgt.append(obj['trg'])
|
||||
src.append(obj['src'])
|
||||
|
||||
elif split == 'test':
|
||||
|
||||
print("***** load " + args.data_name + " test src and tgt dataset*****")
|
||||
src = []
|
||||
tgt = []
|
||||
test_path = os.path.join(args.data_path, args.data_name + "/test.jsonl")
|
||||
with jsonlines.open(test_path) as reader:
|
||||
for obj in reader:
|
||||
tgt.append(obj['trg'])
|
||||
src.append(obj['src'])
|
||||
|
||||
else:
|
||||
print("no such split of data...")
|
||||
exit(1)
|
||||
|
||||
print("example of src text: ", src[50])
|
||||
print("example of tgt text: ", tgt[50])
|
||||
|
||||
if padding_mode == 'max_len':
|
||||
dataset = S2S_dataset(src, tgt, tokenizer, src_maxlength=args.src_max_len, tgt_maxlength=args.tgt_max_len)
|
||||
elif padding_mode == 'block':
print("padding_mode 'block' is not implemented yet")
raise NotImplementedError
else:
raise NotImplementedError
|
||||
|
||||
print("example of src id lists: ", dataset[50][0])
|
||||
print("example of tgt id lists: ", dataset[50][1])
|
||||
print("total query dataset len :", len(dataset))
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class QG_dataset_Diff(Dataset):
|
||||
def __init__(self, src, tgt, tokenizer, src_maxlength=144, answer_maxlength=20, tgt_maxlength=32):
|
||||
self.src = src
|
||||
self.tgt = tgt
|
||||
self.tokenizer = tokenizer
|
||||
self.src_maxlength = src_maxlength
|
||||
self.tgt_maxlength = tgt_maxlength
|
||||
self.ans_maxlength = answer_maxlength
|
||||
|
||||
def __getitem__(self, index):
|
||||
src_example = self.src[index]
|
||||
tgt_example = self.tgt[index]
|
||||
|
||||
answer = src_example.split('[SEP]')[0].strip()
|
||||
passage = src_example.split('[SEP]')[1].strip()
|
||||
|
||||
src_input_ids = self.tokenizer.encode(passage, add_special_tokens=True,
|
||||
max_length=self.src_maxlength, truncation=True,
|
||||
padding='max_length',return_tensors='pt')
|
||||
answer_ids = self.tokenizer.encode(answer, add_special_tokens=True,
|
||||
max_length=self.ans_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
|
||||
max_length=self.tgt_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
|
||||
return src_input_ids, answer_ids, tgt_input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
src_tensor = torch.cat([feature[0] for feature in features])
|
||||
ans_tensor = torch.cat([feature[1] for feature in features])
|
||||
tgt_tensor = torch.cat([feature[2] for feature in features])
|
||||
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
|
||||
"answer_ids": ans_tensor, "answer_mask": (ans_tensor != 0).long(),
|
||||
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
'''
|
||||
s2s for AR model
|
||||
'''
|
||||
class S2S_AR_dataset(Dataset):
|
||||
def __init__(self, src, tgt, tokenizer, src_maxlength=144, tgt_maxlength=32):
|
||||
self.src = src
|
||||
self.tgt = tgt
|
||||
self.tokenizer = tokenizer
|
||||
self.src_maxlength = src_maxlength
|
||||
self.tgt_maxlength = tgt_maxlength
|
||||
|
||||
def __getitem__(self, index):
|
||||
src_example = self.src[index]
|
||||
tgt_example = self.tgt[index]
|
||||
|
||||
src_example = src_example.replace('<S_SEP>', '\n')  # str.replace returns a new string, so the result must be assigned
tgt_example = tgt_example.replace('<S_SEP>', '\n')
|
||||
|
||||
src_input_ids = self.tokenizer.encode(src_example, add_special_tokens=True,
|
||||
max_length=self.src_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
|
||||
max_length=self.tgt_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
tgt_input_ids[tgt_input_ids == 1] = -100  # replace pad ids with -100 so they are ignored by the LM loss (assumes a tokenizer whose pad_token_id is 1, e.g. BART/RoBERTa)
|
||||
|
||||
return src_input_ids, tgt_input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
src_tensor = torch.cat([feature[0] for feature in features])
|
||||
tgt_tensor = torch.cat([feature[1] for feature in features])
|
||||
# print("src shape:", src_tensor.shape)
|
||||
# print("tgt shape:", tgt_tensor.shape)
|
||||
return { "input_ids": src_tensor, "attention_mask": (src_tensor != 0).long(),
|
||||
"labels": tgt_tensor}
|
||||
|
||||
return fn
|
||||
|
||||
class S2S_dataset(Dataset):
|
||||
def __init__(self, src, tgt, tokenizer, src_maxlength=144, tgt_maxlength=32):
|
||||
self.src = src
|
||||
self.tgt = tgt
|
||||
self.tokenizer = tokenizer
|
||||
self.src_maxlength = src_maxlength
|
||||
self.tgt_maxlength = tgt_maxlength
|
||||
|
||||
def __getitem__(self, index):
|
||||
src_example = self.src[index]
|
||||
tgt_example = self.tgt[index]
|
||||
|
||||
src_input_ids = self.tokenizer.encode(src_example, add_special_tokens=True,
|
||||
max_length=self.src_maxlength, truncation=True,
|
||||
padding='max_length',return_tensors='pt')
|
||||
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
|
||||
max_length=self.tgt_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
|
||||
return src_input_ids, tgt_input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
src_tensor = torch.cat([feature[0] for feature in features])
|
||||
tgt_tensor = torch.cat([feature[1] for feature in features])
|
||||
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
|
||||
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
|
||||
|
||||
return fn
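# --- Illustrative sketch (not part of the original file): the batch layout produced by
# S2S_dataset.get_collate_fn(), assuming a tokenizer whose pad id is 0 (e.g. BERT).
def _example_s2s_batch(src, tgt, tokenizer):
    from torch.utils.data import DataLoader
    dataset = S2S_dataset(src, tgt, tokenizer, src_maxlength=144, tgt_maxlength=32)
    loader = DataLoader(dataset, batch_size=16, collate_fn=S2S_dataset.get_collate_fn())
    batch = next(iter(loader))
    # batch["src_input_ids"]: (16, 144); batch["tgt_input_ids"]: (16, 32);
    # the *_attention_mask entries are 1 wherever the ids are non-zero (non-pad)
    return batch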
|
||||
|
||||
class S2S_imp_dataset(Dataset):
|
||||
def __init__(self, src, tgt, ori_gen, tokenizer, src_maxlength=144, tgt_maxlength=32):
|
||||
self.src = src
|
||||
self.tgt = tgt
|
||||
self.ori_gen = ori_gen
|
||||
self.tokenizer = tokenizer
|
||||
self.src_maxlength = src_maxlength
|
||||
self.tgt_maxlength = tgt_maxlength
|
||||
|
||||
def __getitem__(self, index):
|
||||
src_example = self.src[index]
|
||||
tgt_example = self.tgt[index]
|
||||
ori_gen_example = self.ori_gen[index]
|
||||
|
||||
src_input_ids = self.tokenizer.encode(src_example, add_special_tokens=True,
|
||||
max_length=self.src_maxlength, truncation=True,
|
||||
padding='max_length',return_tensors='pt')
|
||||
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
|
||||
max_length=self.tgt_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
ori_gen_ids = self.tokenizer.encode(ori_gen_example, add_special_tokens=True,
|
||||
max_length=self.tgt_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
|
||||
return src_input_ids, tgt_input_ids, ori_gen_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
src_tensor = torch.cat([feature[0] for feature in features])
|
||||
tgt_tensor = torch.cat([feature[1] for feature in features])
|
||||
ori_gen_tensor = torch.cat([feature[2] for feature in features])
|
||||
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
|
||||
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long(),
|
||||
"ori_gen_ids": ori_gen_tensor}
|
||||
|
||||
return fn
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
|
@ -0,0 +1,343 @@
|
|||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from util import logger
|
||||
from train_util import dist_util
|
||||
import numpy as np
|
||||
|
||||
def load_data_text(data_args, emb_model=None, padding_mode='max_len', split='train', tokenizer=None, return_hidden=False):
|
||||
|
||||
'''
|
||||
init embedding model
|
||||
'''
|
||||
if split == 'train':
|
||||
if data_args.token_emb_type == 'random' and emb_model is None:
|
||||
print('loading initialized random embeddings. ')
|
||||
emb_model = None
|
||||
elif data_args.token_emb_type == 'pretrain' and emb_model is not None:
|
||||
print('loading embeddings from pretraining embedding. ')
|
||||
|
||||
'''
|
||||
load query text dataset (only query) according to split
|
||||
'''
|
||||
if data_args.model_arch == 'transformer':
|
||||
dataset = get_query_corpus(data_args, padding_mode=padding_mode, split=split,
|
||||
tokenizer=tokenizer)
|
||||
elif data_args.model_arch == 's2s_CAT':
|
||||
dataset = get_pandq_corpus(data_args, padding_mode=padding_mode, split=split,
|
||||
tokenizer=tokenizer)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
'''
|
||||
load embedding model
|
||||
'''
|
||||
# if data_args.token_emb_type == 'random' and emb_model is None:
|
||||
# emb_model = torch.nn.Embedding(tokenizer.vocab_size, data_args.in_channel)
|
||||
# print('initializing the random embeddings', emb_model)
|
||||
# torch.nn.init.normal_(emb_model.weight)
|
||||
# path_save = f'{data_args.checkpoint_path}/random_emb.torch'
|
||||
# print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
|
||||
# torch.save(emb_model.state_dict(), path_save)
|
||||
# elif data_args.token_emb_type == 'pretrain':
|
||||
# # TODO: finish pretrain embedding setting
|
||||
# print('initializing the pretrain embeddings', emb_model)
|
||||
# else:
|
||||
# return NotImplementedError
|
||||
|
||||
if return_hidden:
|
||||
'''
|
||||
encode input ids to input embedding
|
||||
load temp dataloader to input in embedding model, getting hidden_state (word embedding)
|
||||
'''
|
||||
query_dataloader = DataLoader(dataset, batch_size=1024, drop_last=False,
|
||||
num_workers=20, collate_fn=Question_dataset.get_collate_fn())
|
||||
emb_model.to(dist_util.dev())
|
||||
hidden_state_set = []
|
||||
with torch.no_grad():
|
||||
for k, (ids, text_ids, text_mask) in enumerate(tqdm(query_dataloader)):
|
||||
hidden_state = emb_model(text_ids.long().to(dist_util.dev()))
|
||||
# hidden_state :(batch_size, seq_len, hidden_size)
|
||||
hidden_state = hidden_state.detach().cpu()
|
||||
hidden_state_set.append(hidden_state)
|
||||
hidden_state_set = torch.cat(hidden_state_set, dim=0)
|
||||
|
||||
'''
|
||||
load query embedding dataloader
|
||||
'''
|
||||
query_emb_dataset = Text_Hidden_dataset(dataset, hidden_state_set)
|
||||
# query_dataloader = DataLoader(query_emb_dataset, batch_size=data_args.batch_size, drop_last=False,
|
||||
# num_workers=20, collate_fn=Text_Hidden_dataset.get_collate_fn())
|
||||
|
||||
emb_model.cpu()
|
||||
return query_emb_dataset, emb_model
|
||||
|
||||
else:
|
||||
return dataset, emb_model
|
||||
|
||||
# for input_ids in dataset['word_ids']:
|
||||
# if data_args.experiment.startswith('random'):
|
||||
# hidden_state = model(torch.tensor(input_ids))
|
||||
# elif data_args.experiment == 'gpt2_pre_compress':
|
||||
# input_ids2 = torch.tensor(input_ids).to(model.device)
|
||||
# input_embs = model.transformer.wte(input_ids2) # input_embs
|
||||
# hidden_state = model.down_proj(input_embs)
|
||||
# hidden_state = hidden_state * data_args.emb_scale_factor
|
||||
# elif data_args.experiment == 'glove':
|
||||
# hidden_state = model(torch.tensor(input_ids))
|
||||
# result_train_lst.append({'input_ids': input_ids, 'hidden_states': hidden_state.cpu().tolist()})
|
||||
|
||||
|
||||
def get_query_corpus(args, padding_mode, split, tokenizer):
|
||||
|
||||
questions = []
|
||||
if split == 'train':
|
||||
print("***** load train query dataset*****")
|
||||
train_question_path = os.path.join(args.data_path, "train.query.txt")
|
||||
with open(train_question_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
id, text = line.split('\t')
|
||||
questions.append([int(id), text])
|
||||
elif split == 'dev':
|
||||
dev_question_path = os.path.join(args.data_path, "dev.query.txt")
|
||||
with open(dev_question_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
id, text = line.split('\t')
|
||||
questions.append([int(id), text])
|
||||
else:
|
||||
print("no such split of data...")
|
||||
exit(1)
|
||||
|
||||
print("example of questions text: ", questions[50])
|
||||
|
||||
if padding_mode == 'max_len':
|
||||
dataset = Question_dataset(questions, tokenizer, maxlength=args.text_max_len)
|
||||
elif padding_mode == 'block':
print("padding_mode 'block' is not implemented yet")
raise NotImplementedError
else:
raise NotImplementedError
|
||||
|
||||
print("example of questions id lists: ", dataset[50])
|
||||
print("total query dataset len :", len(dataset))
|
||||
|
||||
return dataset
|
||||
|
||||
def get_pandq_corpus(args, padding_mode, split, tokenizer):
|
||||
questions = []
|
||||
if split == 'train':
|
||||
print("***** load train query dataset*****")
|
||||
train_question_path = os.path.join(args.data_path, "train.query.txt")
|
||||
with open(train_question_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
id, text = line.split('\t')
|
||||
questions.append([int(id), text])
|
||||
print("***** load query to passage qrel data*****")
|
||||
qrel_path = os.path.join(args.data_path, "qrels.train.tsv")
|
||||
qids_to_relevant_passageids = load_train_reference_from_stream(qrel_path)
|
||||
|
||||
|
||||
elif split == 'dev':
|
||||
dev_question_path = os.path.join(args.data_path, "dev.query.txt")
|
||||
with open(dev_question_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
id, text = line.split('\t')
|
||||
questions.append([int(id), text])
|
||||
print("***** load query to passage qrel data*****")
|
||||
qrel_path = os.path.join(args.data_path, "qrels.dev.tsv")
|
||||
qids_to_relevant_passageids = load_dev_reference_from_stream(qrel_path)
|
||||
else:
|
||||
print("no such split of data...")
|
||||
exit(1)
|
||||
|
||||
print("***** load passages dataset*****")
|
||||
passage_title_path = os.path.join(args.data_path, "para.title.txt")
|
||||
passage_ctx_path = os.path.join(args.data_path, "para.txt")
|
||||
passage_title = load_id_text(passage_title_path)
|
||||
passages = {}
|
||||
with open(passage_ctx_path) as inp:
|
||||
for line in tqdm(inp):
|
||||
line = line.strip()
|
||||
id, text = line.split('\t')
|
||||
passages[int(id)] = (text, passage_title.get(id, '-'))
|
||||
|
||||
print("example of questions text: ", questions[50])
|
||||
qid = questions[50][0]
|
||||
pid = qids_to_relevant_passageids[qid][0]
|
||||
print("example of passage text --title: ", passages[pid][1])
|
||||
print("example of passage text --passage: ", passages[pid][0])
|
||||
|
||||
|
||||
if padding_mode == 'max_len':
|
||||
dataset = PandQ_dataset(questions, qids_to_relevant_passageids, passages,
|
||||
tokenizer, q_maxlength=args.text_max_len, p_maxlength=args.pas_max_len)
|
||||
elif padding_mode == 'block':
print("padding_mode 'block' is not implemented yet")
raise NotImplementedError
else:
raise NotImplementedError
|
||||
|
||||
print("example of questions id lists: ", dataset[50][0])
|
||||
print("example of passage id lists: ", dataset[50][1])
|
||||
print("total query dataset len :", len(dataset))
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class Text_Hidden_dataset(Dataset):
|
||||
def __init__(self, query_dataset, hidden_state_set):
|
||||
self.query_dataset = query_dataset
|
||||
self.hidden_state_set = hidden_state_set
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = self.query_dataset[index]
|
||||
hidden_state = self.hidden_state_set[index]
|
||||
|
||||
return example[0], example[1], hidden_state
|
||||
|
||||
'''
|
||||
query_ids : len(np list) = batch_size query
|
||||
input_ids : (batch_size, seq_len)
|
||||
attention_mask : (batch_size, seq_len)
|
||||
hidden_state : (batch_size, seq_len, embedding)
|
||||
'''
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
id_list = [feature[0] for feature in features]
|
||||
q_tensor = torch.cat([feature[1] for feature in features])
|
||||
hidden_state = torch.cat([feature[2] for feature in features])
|
||||
return {"query_ids":np.array(id_list), "input_ids":q_tensor,
|
||||
"attention_mask":(q_tensor != 0).long(), "hidden_state":hidden_state}
|
||||
|
||||
return fn
|
||||
|
||||
class PandQ_dataset(Dataset):
|
||||
def __init__(self, questions, qids_to_relevant_passageids, passages,
|
||||
tokenizer, q_maxlength=32, p_maxlength=144):
|
||||
self.questions = questions
|
||||
self.tokenizer = tokenizer
|
||||
self.q_maxlength = q_maxlength
|
||||
self.p_maxlength = p_maxlength
|
||||
self.qids_to_relevant_passageids = qids_to_relevant_passageids
|
||||
self.passages = passages
|
||||
|
||||
def __getitem__(self, index):
|
||||
q_example = self.questions[index]
|
||||
query_id = q_example[0]
|
||||
q_input_ids = self.tokenizer.encode(q_example[1], add_special_tokens=True,
|
||||
max_length=self.q_maxlength, truncation=True,
|
||||
padding='max_length',return_tensors='pt')
|
||||
rel_passage_id = self.qids_to_relevant_passageids[query_id][0]
|
||||
passage_example = self.passages[rel_passage_id]
|
||||
text = passage_example[0]
|
||||
title = passage_example[1]
|
||||
p_input_ids = self.tokenizer.encode(title, text_pair=text, add_special_tokens=True,
|
||||
max_length=self.p_maxlength, truncation=True,
|
||||
padding='max_length', return_tensors='pt')
|
||||
|
||||
return q_input_ids, p_input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.questions)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
q_tensor = torch.cat([feature[0] for feature in features])
|
||||
p_tensor = torch.cat([feature[1] for feature in features])
|
||||
return { "q_input_ids": q_tensor, "q_attention_mask": (q_tensor != 0).long(),
|
||||
"p_input_ids": p_tensor, "p_attention_mask": (p_tensor != 0).long() }
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
|
||||
class Question_dataset(Dataset):
|
||||
def __init__(self, questions, tokenizer,maxlength=32):
|
||||
self.questions = questions
|
||||
self.tokenizer = tokenizer
|
||||
self.maxlength = maxlength
|
||||
def __getitem__(self, index):
|
||||
example = self.questions[index]
|
||||
input_ids = self.tokenizer.encode(example[1], add_special_tokens=True,
|
||||
max_length=self.maxlength, truncation=True,
|
||||
padding='max_length',return_tensors='pt')
|
||||
return example[0], input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.questions)
|
||||
|
||||
@classmethod
|
||||
def get_collate_fn(cls):
|
||||
def fn(features):
|
||||
id_list = [feature[0] for feature in features]
|
||||
q_tensor = torch.cat([feature[1] for feature in features])
|
||||
return np.array(id_list), q_tensor, (q_tensor != 0).long()
|
||||
return fn
|
||||
|
||||
|
||||
def load_dev_reference_from_stream(path_to_reference):
|
||||
"""Load Reference reference relevant passages
|
||||
Args:f (stream): stream to load.
|
||||
Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
|
||||
"""
|
||||
qids_to_relevant_passageids = {}
|
||||
with open(path_to_reference, 'r') as f:
|
||||
for l in f:
|
||||
try:
|
||||
l = l.strip().split('\t')
|
||||
qid = int(l[0])
|
||||
if qid in qids_to_relevant_passageids:
|
||||
pass
|
||||
else:
|
||||
qids_to_relevant_passageids[qid] = []
|
||||
qids_to_relevant_passageids[qid].append(int(l[1]))
|
||||
except:
|
||||
raise IOError('"%s" is not in a valid format' % l)
|
||||
return qids_to_relevant_passageids
|
||||
|
||||
|
||||
def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1):
|
||||
def gen():
|
||||
for i, line in tqdm(enumerate(fd)):
|
||||
if i % trainer_num == trainer_id:
|
||||
slots = line.rstrip('\n').split(delimiter)
|
||||
if len(slots) == 1:
|
||||
yield slots,
|
||||
else:
|
||||
yield slots
|
||||
return gen()
|
||||
def load_train_reference_from_stream(input_file, trainer_id=0, trainer_num=1):
|
||||
"""Reads a tab separated value file."""
|
||||
with open(input_file, 'r', encoding='utf8') as f:
|
||||
reader = csv_reader(f, trainer_id=trainer_id, trainer_num=trainer_num)
|
||||
#headers = 'query_id\tpos_id\tneg_id'.split('\t')
|
||||
|
||||
#Example = namedtuple('Example', headers)
|
||||
qrel = {}
|
||||
for [topicid, _, docid, rel] in reader:
|
||||
topicid = int(topicid)
|
||||
assert rel == "1"
|
||||
if topicid in qrel:
|
||||
qrel[topicid].append(int(docid))
|
||||
else:
|
||||
qrel[topicid] = [int(docid)]
|
||||
return qrel
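# --- Illustrative note (not part of the original file): load_train_reference_from_stream
# expects MS MARCO style qrels rows "<query_id>\t<iteration>\t<passage_id>\t<relevance>"
# and keeps only relevance "1" rows. With a hypothetical line "1185869\t0\t16\t1" it
# would return {1185869: [16]}.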
|
||||
|
||||
def load_id_text(file_name):
|
||||
"""load tsv files"""
|
||||
id_text = {}
|
||||
with open(file_name) as inp:
|
||||
for line in tqdm(inp):
|
||||
line = line.strip()
|
||||
id, text = line.split('\t')
|
||||
id_text[id] = text
|
||||
return id_text
|
||||
|
(Diff for this file not shown because of its size.)
|
@ -0,0 +1,157 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def create_named_schedule_sampler(name, diffusion):
|
||||
"""
|
||||
Create a ScheduleSampler from a library of pre-defined samplers.
|
||||
|
||||
:param name: the name of the sampler.
|
||||
:param diffusion: the diffusion object to sample for.
|
||||
"""
|
||||
if name == "uniform":
|
||||
print("using uniform time schedule sampler")
|
||||
return UniformSampler(diffusion)
|
||||
elif name == "loss-second-moment":
|
||||
print("using loss-second-moment time schedule sampler")
|
||||
return LossSecondMomentResampler(diffusion)
|
||||
else:
|
||||
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
||||
|
||||
|
||||
class ScheduleSampler(ABC):
|
||||
"""
|
||||
A distribution over timesteps in the diffusion process, intended to reduce
|
||||
variance of the objective.
|
||||
|
||||
By default, samplers perform unbiased importance sampling, in which the
|
||||
objective's mean is unchanged.
|
||||
However, subclasses may override sample() to change how the resampled
|
||||
terms are reweighted, allowing for actual changes in the objective.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def weights(self):
|
||||
"""
|
||||
Get a numpy array of weights, one per diffusion step.
|
||||
|
||||
The weights needn't be normalized, but must be positive.
|
||||
"""
|
||||
|
||||
def sample(self, batch_size, device):
|
||||
"""
|
||||
Importance-sample timesteps for a batch.
|
||||
|
||||
:param batch_size: the number of timesteps.
|
||||
:param device: the torch device to save to.
|
||||
:return: a tuple (timesteps, weights):
|
||||
- timesteps: a tensor of timestep indices.
|
||||
- weights: a tensor of weights to scale the resulting losses.
|
||||
"""
|
||||
w = self.weights()
|
||||
# print(w)
|
||||
p = w / np.sum(w)
|
||||
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
||||
indices = th.from_numpy(indices_np).long().to(device)
|
||||
weights_np = 1 / (len(p) * p[indices_np])
|
||||
weights = th.from_numpy(weights_np).float().to(device)
|
||||
return indices, weights
|
||||
|
||||
|
||||
class UniformSampler(ScheduleSampler):
|
||||
def __init__(self, diffusion):
|
||||
self.diffusion = diffusion
|
||||
self._weights = np.ones([diffusion.num_timesteps])
|
||||
|
||||
def weights(self):
|
||||
return self._weights
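# --- Illustrative sketch (not part of the original file): drawing timesteps for a
# training batch with the uniform sampler. `diffusion` only needs to expose
# num_timesteps; the helper name is hypothetical.
def _example_sample_timesteps(diffusion, batch_size, device):
    sampler = UniformSampler(diffusion)
    t, weights = sampler.sample(batch_size, device)
    # t: (batch_size,) long tensor of timestep indices
    # weights: (batch_size,) float tensor; all ones for the uniform sampler
    return t, weights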
|
||||
|
||||
|
||||
class LossAwareSampler(ScheduleSampler):
|
||||
def update_with_local_losses(self, local_ts, local_losses):
|
||||
"""
|
||||
Update the reweighting using losses from a model.
|
||||
|
||||
Call this method from each rank with a batch of timesteps and the
|
||||
corresponding losses for each of those timesteps.
|
||||
This method will perform synchronization to make sure all of the ranks
|
||||
maintain the exact same reweighting.
|
||||
|
||||
:param local_ts: an integer Tensor of timesteps.
|
||||
:param local_losses: a 1D Tensor of losses.
|
||||
"""
|
||||
batch_sizes = [
|
||||
th.tensor([0], dtype=th.int32, device=local_ts.device)
|
||||
for _ in range(dist.get_world_size())
|
||||
]
|
||||
dist.all_gather(
|
||||
batch_sizes,
|
||||
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
||||
)
|
||||
|
||||
# Pad all_gather batches to be the maximum batch size.
|
||||
batch_sizes = [x.item() for x in batch_sizes]
|
||||
max_bs = max(batch_sizes)
|
||||
|
||||
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
||||
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
||||
dist.all_gather(timestep_batches, local_ts)
|
||||
dist.all_gather(loss_batches, local_losses)
|
||||
timesteps = [
|
||||
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
|
||||
]
|
||||
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
||||
self.update_with_all_losses(timesteps, losses)
|
||||
|
||||
@abstractmethod
|
||||
def update_with_all_losses(self, ts, losses):
|
||||
"""
|
||||
Update the reweighting using losses from a model.
|
||||
|
||||
Sub-classes should override this method to update the reweighting
|
||||
using losses from the model.
|
||||
|
||||
This method directly updates the reweighting without synchronizing
|
||||
between workers. It is called by update_with_local_losses from all
|
||||
ranks with identical arguments. Thus, it should have deterministic
|
||||
behavior to maintain state across workers.
|
||||
|
||||
:param ts: a list of int timesteps.
|
||||
:param losses: a list of float losses, one per timestep.
|
||||
"""
|
||||
|
||||
|
||||
class LossSecondMomentResampler(LossAwareSampler):
|
||||
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
||||
self.diffusion = diffusion
|
||||
self.history_per_term = history_per_term
|
||||
self.uniform_prob = uniform_prob
|
||||
self._loss_history = np.zeros(
|
||||
[diffusion.num_timesteps, history_per_term], dtype=np.float64
|
||||
)
|
||||
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int is removed in recent NumPy versions
|
||||
|
||||
def weights(self):
|
||||
if not self._warmed_up():
|
||||
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
||||
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
|
||||
weights /= np.sum(weights)
|
||||
weights *= 1 - self.uniform_prob
|
||||
weights += self.uniform_prob / len(weights)
|
||||
return weights
|
||||
|
||||
def update_with_all_losses(self, ts, losses):
|
||||
for t, loss in zip(ts, losses):
|
||||
if self._loss_counts[t] == self.history_per_term:
|
||||
# Shift out the oldest loss term.
|
||||
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
||||
self._loss_history[t, -1] = loss
|
||||
else:
|
||||
self._loss_history[t, self._loss_counts[t]] = loss
|
||||
self._loss_counts[t] += 1
|
||||
|
||||
def _warmed_up(self):
|
||||
return (self._loss_counts == self.history_per_term).all()
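# --- Illustrative sketch (not part of the original file): how a loss-aware sampler is
# typically fed back from the training step. The diffusion.training_losses call and its
# returned dict are assumptions about the surrounding training loop, which is not shown here.
def _example_loss_aware_step(schedule_sampler, diffusion, model, x_start, device):
    t, weights = schedule_sampler.sample(x_start.shape[0], device)
    losses = diffusion.training_losses(model, x_start, t)  # assumed to return {"loss": (batch,)}
    if isinstance(schedule_sampler, LossAwareSampler):
        schedule_sampler.update_with_local_losses(t, losses["loss"].detach())
    return (losses["loss"] * weights).mean()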
|
|
@ -0,0 +1,133 @@
|
|||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
from diffusion_util.gaussian_diffusion import GaussianDiffusion
|
||||
|
||||
|
||||
def space_timesteps(num_timesteps, section_counts):
|
||||
"""
|
||||
Create a list of timesteps to use from an original diffusion process,
|
||||
given the number of timesteps we want to take from equally-sized portions
|
||||
of the original process.
|
||||
|
||||
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
||||
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
||||
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
||||
|
||||
If the stride is a string starting with "ddim", then the fixed striding
|
||||
from the DDIM paper is used, and only one section is allowed.
|
||||
|
||||
:param num_timesteps: the number of diffusion steps in the original
|
||||
process to divide up.
|
||||
:param section_counts: either a list of numbers, or a string containing
|
||||
comma-separated numbers, indicating the step count
|
||||
per section. As a special case, use "ddimN" where N
|
||||
is a number of steps to use the striding from the
|
||||
DDIM paper.
|
||||
:return: a set of diffusion steps from the original process to use.
|
||||
"""
|
||||
if isinstance(section_counts, str):
|
||||
if section_counts.startswith("ddim"):
|
||||
desired_count = int(section_counts[len("ddim") :])
|
||||
for i in range(1, num_timesteps):
|
||||
if len(range(0, num_timesteps, i)) == desired_count:
|
||||
return set(range(0, num_timesteps, i))
|
||||
raise ValueError(
|
||||
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
||||
)
|
||||
section_counts = [int(x) for x in section_counts.split(",")]
|
||||
size_per = num_timesteps // len(section_counts)
|
||||
extra = num_timesteps % len(section_counts)
|
||||
start_idx = 0
|
||||
all_steps = []
|
||||
for i, section_count in enumerate(section_counts):
|
||||
size = size_per + (1 if i < extra else 0)
|
||||
if size < section_count:
|
||||
raise ValueError(
|
||||
f"cannot divide section of {size} steps into {section_count}"
|
||||
)
|
||||
if section_count <= 1:
|
||||
frac_stride = 1
|
||||
else:
|
||||
frac_stride = (size - 1) / (section_count - 1)
|
||||
cur_idx = 0.0
|
||||
taken_steps = []
|
||||
for _ in range(section_count):
|
||||
taken_steps.append(start_idx + round(cur_idx))
|
||||
cur_idx += frac_stride
|
||||
all_steps += taken_steps
|
||||
start_idx += size
|
||||
return set(all_steps)
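# --- Illustrative examples (not part of the original file) of what space_timesteps returns:
#   space_timesteps(300, [10, 15, 20]) -> 45 steps: 10 strided over [0, 100), 15 over
#   [100, 200) and 20 over [200, 300), as described in the docstring above.
#   space_timesteps(2000, "ddim50")    -> {0, 40, 80, ..., 1960} (stride 40, 50 steps).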
|
||||
|
||||
|
||||
class SpacedDiffusion(GaussianDiffusion):
|
||||
"""
|
||||
A diffusion process which can skip steps in a base diffusion process.
|
||||
|
||||
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
||||
original diffusion process to retain.
|
||||
:param kwargs: the kwargs to create the base diffusion process.
|
||||
"""
|
||||
|
||||
def __init__(self, use_timesteps, **kwargs):
|
||||
self.use_timesteps = set(use_timesteps)
|
||||
self.timestep_map = []
|
||||
self.original_num_steps = len(kwargs["betas"])
|
||||
|
||||
# print(kwargs.keys())
|
||||
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
||||
|
||||
# Recompute betas for the timesteps kept in the spaced schedule: beta'_t = 1 - alpha_bar_t / alpha_bar_{t_prev},
# so the cumulative alpha products of the shortened process match the base process at the retained steps.
|
||||
last_alpha_cumprod = 1.0
|
||||
new_betas = []
|
||||
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
||||
if i in self.use_timesteps:
|
||||
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
||||
last_alpha_cumprod = alpha_cumprod
|
||||
self.timestep_map.append(i)
|
||||
kwargs["betas"] = np.array(new_betas)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def p_mean_variance(
|
||||
self, model, *args, **kwargs
|
||||
): # pylint: disable=signature-differs
|
||||
# print('called p_mean_var')
|
||||
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def training_losses(
|
||||
self, model, *args, **kwargs
|
||||
): # pylint: disable=signature-differs
|
||||
# print('called training_losses')
|
||||
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def _wrap_model(self, model):
|
||||
if isinstance(model, _WrappedModel):
|
||||
return model
|
||||
return _WrappedModel(
|
||||
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
|
||||
)
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
# Scaling is done by the wrapped model.
|
||||
return t
|
||||
|
||||
|
||||
class _WrappedModel:
|
||||
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
||||
self.model = model
|
||||
self.timestep_map = timestep_map
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
self.original_num_steps = original_num_steps
|
||||
|
||||
def __call__(self, x, ts, **kwargs):
|
||||
# print(ts)
|
||||
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
||||
new_ts = map_tensor[ts]
|
||||
# print(new_ts)
|
||||
if self.rescale_timesteps:
|
||||
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
||||
# temp = self.model(x, new_ts, **kwargs)
|
||||
# print(temp.shape)
|
||||
# return temp
|
||||
# print(new_ts)
|
||||
return self.model(x, new_ts, **kwargs)
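# --- Illustrative sketch (not part of the original file): building a spaced (shortened)
# diffusion for fast sampling from a model trained with the full schedule. The keyword
# arguments other than `betas` depend on the GaussianDiffusion constructor in
# diffusion_util.gaussian_diffusion and are only passed through here.
def _example_spaced_diffusion(betas, **gaussian_diffusion_kwargs):
    use_timesteps = space_timesteps(len(betas), "ddim50")
    return SpacedDiffusion(use_timesteps, betas=betas, **gaussian_diffusion_kwargs)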
|
(Binary file not shown: image, 187 KiB after the change.)
|
@ -0,0 +1,269 @@
|
|||
import os
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import string
|
||||
import rouge
|
||||
|
||||
def get_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# out path
|
||||
parser.add_argument('--generate_path', type=str, default='', help='output path')
|
||||
parser.add_argument('--num_samples', type=int, default=50, help='sample query')
|
||||
|
||||
# data args
|
||||
parser.add_argument('--data_path', type=str, default='', help='data path')
|
||||
parser.add_argument('--data_name', type=str, default='', help='data name')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='')
|
||||
|
||||
# seed
|
||||
parser.add_argument('--seed', type=int, default=101, help='')
|
||||
parser.add_argument('--n_gpu', type=int, default=4, help='')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
_tok_dict = {"(": "-lrb-", ")": "-rrb-",
|
||||
"[": "-lsb-", "]": "-rsb-",
|
||||
"{": "-lcb-", "}": "-rcb-",
|
||||
"[UNK]": "UNK", '&': '&', '<': '<', '>': '>'}
|
||||
|
||||
def _is_digit(w):
|
||||
for ch in w:
|
||||
if not(ch.isdigit() or ch == ','):
|
||||
return False
|
||||
return True
|
||||
|
||||
def fix_tokenization(text):
|
||||
input_tokens = text.split()
|
||||
output_tokens = []
|
||||
has_left_quote = False
|
||||
has_left_single_quote = False
|
||||
|
||||
i = 0
|
||||
prev_dash = False
|
||||
while i < len(input_tokens):
|
||||
tok = input_tokens[i]
|
||||
flag_prev_dash = False
|
||||
if tok in _tok_dict.keys():
|
||||
output_tokens.append(_tok_dict[tok])
|
||||
i += 1
|
||||
elif tok == "\"":
|
||||
if has_left_quote:
|
||||
output_tokens.append("''")
|
||||
else:
|
||||
output_tokens.append("``")
|
||||
has_left_quote = not has_left_quote
|
||||
i += 1
|
||||
elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t":
|
||||
output_tokens[-1] = output_tokens[-1][:-1]
|
||||
output_tokens.append("n't")
|
||||
i += 2
|
||||
elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"):
|
||||
output_tokens.append("'"+input_tokens[i + 1])
|
||||
i += 2
|
||||
elif tok == "'":
|
||||
if has_left_single_quote:
|
||||
output_tokens.append("'")
|
||||
else:
|
||||
output_tokens.append("`")
|
||||
has_left_single_quote = not has_left_single_quote
|
||||
i += 1
|
||||
elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".":
|
||||
output_tokens.append("...")
|
||||
i += 3
|
||||
elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]):
|
||||
# $ 3 , 000 -> $ 3,000
|
||||
output_tokens[-1] += ','+input_tokens[i + 1]
|
||||
i += 2
|
||||
elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit():
|
||||
# 3 . 03 -> 3.03
|
||||
output_tokens[-1] += '.'+input_tokens[i + 1]
|
||||
i += 2
|
||||
elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.':
|
||||
# U . N . -> U.N.
|
||||
k = i+3
|
||||
while k+2 < len(input_tokens):
|
||||
if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.':
|
||||
k += 2
|
||||
else:
|
||||
break
|
||||
output_tokens[-1] += ''.join(input_tokens[i:k])
|
||||
i = k  # skip past the abbreviation that was just merged (e.g. "U . N ." -> "U.N.")
|
||||
elif tok == "-":
|
||||
if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-":
|
||||
output_tokens.append("--")
|
||||
i += 2
|
||||
elif i == len(input_tokens) - 1 or i == 0:
|
||||
output_tokens.append("-")
|
||||
i += 1
|
||||
elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation:
|
||||
output_tokens[-1] += "-"
|
||||
i += 1
|
||||
flag_prev_dash = True
|
||||
else:
|
||||
output_tokens.append("-")
|
||||
i += 1
|
||||
elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation:
|
||||
output_tokens[-1] += tok
|
||||
i += 1
|
||||
else:
|
||||
output_tokens.append(tok)
|
||||
i += 1
|
||||
prev_dash = flag_prev_dash
|
||||
return " ".join(output_tokens)
|
||||
|
||||
def process_eval(args, gen_list, tgt_list):
|
||||
evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2,
|
||||
limit_length=False, apply_avg=True, weight_factor=1.2)
|
||||
|
||||
max_score = []
|
||||
avg_score = []
|
||||
lowest_score = []
|
||||
best_gen_list = []
|
||||
# max score,avg score,lowest score
|
||||
for index, sentence in enumerate(tqdm(gen_list)):
|
||||
if index % args.num_samples == 0:
|
||||
max_score_dict = {'rouge_1':0.0,'rouge_2':0.0,'rouge_l':0.0}
|
||||
avg_score_dict = {'rouge_1': 0.0, 'rouge_2': 0.0, 'rouge_l': 0.0}
|
||||
low_score_dict = {'rouge_1': 1.0, 'rouge_2': 1.0, 'rouge_l': 1.0}
|
||||
target = tgt_list[index // args.num_samples]
|
||||
scores = evaluator.get_scores([sentence], [[target]])
|
||||
rouge_1 = scores['rouge-1']['f']
|
||||
rouge_2 = scores['rouge-2']['f']
|
||||
rouge_l = scores['rouge-l']['f']
|
||||
# max
|
||||
if rouge_2 >= max_score_dict['rouge_2']:
|
||||
if rouge_2 != 0:
|
||||
max_score_dict['rouge_1'] = rouge_1
|
||||
max_score_dict['rouge_2'] = rouge_2
|
||||
max_score_dict['rouge_l'] = rouge_l
|
||||
best_sentence = sentence
|
||||
else:
|
||||
if rouge_1 >= max_score_dict['rouge_1']:
|
||||
max_score_dict['rouge_1'] = rouge_1
|
||||
max_score_dict['rouge_2'] = rouge_2
|
||||
max_score_dict['rouge_l'] = rouge_l
|
||||
best_sentence = sentence
|
||||
# avg
|
||||
avg_score_dict['rouge_1'] += rouge_1
|
||||
avg_score_dict['rouge_2'] += rouge_2
|
||||
avg_score_dict['rouge_l'] += rouge_l
|
||||
# min
|
||||
if rouge_2 < low_score_dict['rouge_2']:
|
||||
low_score_dict['rouge_1'] = rouge_1
|
||||
low_score_dict['rouge_2'] = rouge_2
|
||||
low_score_dict['rouge_l'] = rouge_l
|
||||
if (index + 1) % args.num_samples == 0:
|
||||
max_score.append(max_score_dict)
|
||||
best_gen_list.append(best_sentence)
|
||||
avg_score_dict['rouge_1'] = avg_score_dict['rouge_1'] / args.num_samples
|
||||
avg_score_dict['rouge_2'] = avg_score_dict['rouge_2'] / args.num_samples
|
||||
avg_score_dict['rouge_l'] = avg_score_dict['rouge_l'] / args.num_samples
|
||||
avg_score.append(avg_score_dict)
|
||||
lowest_score.append(low_score_dict)
|
||||
|
||||
rouge_1 = 0
|
||||
rouge_2 = 0
|
||||
rouge_l = 0
|
||||
for score_dict in max_score:
|
||||
rouge_1 += score_dict['rouge_1']
|
||||
rouge_2 += score_dict['rouge_2']
|
||||
rouge_l += score_dict['rouge_l']
|
||||
max_rouge_1 = rouge_1 / len(max_score)
|
||||
max_rouge_2 = rouge_2 / len(max_score)
|
||||
max_rouge_l = rouge_l / len(max_score)
|
||||
|
||||
rouge_1 = 0
|
||||
rouge_2 = 0
|
||||
rouge_l = 0
|
||||
for score_dict in avg_score:
|
||||
rouge_1 += score_dict['rouge_1']
|
||||
rouge_2 += score_dict['rouge_2']
|
||||
rouge_l += score_dict['rouge_l']
|
||||
avg_rouge_1 = rouge_1 / len(max_score)
|
||||
avg_rouge_2 = rouge_2 / len(max_score)
|
||||
avg_rouge_l = rouge_l / len(max_score)
|
||||
|
||||
rouge_1 = 0
|
||||
rouge_2 = 0
|
||||
rouge_l = 0
|
||||
for score_dict in lowest_score:
|
||||
rouge_1 += score_dict['rouge_1']
|
||||
rouge_2 += score_dict['rouge_2']
|
||||
rouge_l += score_dict['rouge_l']
|
||||
min_rouge_1 = rouge_1 / len(max_score)
|
||||
min_rouge_2 = rouge_2 / len(max_score)
|
||||
min_rouge_l = rouge_l / len(max_score)
|
||||
|
||||
scores = {'max_rouge_1':max_rouge_1, 'max_rouge_2':max_rouge_2, 'max_rouge_l':max_rouge_l,
|
||||
'min_rouge_1':min_rouge_1, 'min_rouge_2':min_rouge_2, 'min_rouge_l':min_rouge_l,
|
||||
'avg_rouge_1':avg_rouge_1, 'avg_rouge_2':avg_rouge_2, 'avg_rouge_l':avg_rouge_l }
|
||||
|
||||
return scores, best_gen_list
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
args = get_arguments()
|
||||
tgt = []
|
||||
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
|
||||
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
# text = fix_tokenization(line)
|
||||
tgt.append(text)
|
||||
|
||||
final_scores = {'max_rouge_1': 0.0, 'max_rouge_2': 0.0, 'max_rouge_l': 0.0,
|
||||
'min_rouge_1': 0.0, 'min_rouge_2': 0.0, 'min_rouge_l': 0.0,
|
||||
'avg_rouge_1': 0.0, 'avg_rouge_2': 0.0, 'avg_rouge_l': 0.0}
|
||||
|
||||
tgt_offset = 0
|
||||
best_gen_list_total = []
|
||||
|
||||
for i in range(args.n_gpu):
|
||||
|
||||
gen_list_total = []
|
||||
for epoch in range(args.num_samples):
|
||||
gen_list = []
|
||||
gen_text_path = os.path.join(args.generate_path, "rank" + str(i)+"_gen_seed_" + str(args.seed) +
|
||||
"_num" + str(args.num_samples) + "_epoch" + str(epoch+1) + ".txt")
|
||||
|
||||
with open(gen_text_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip()
|
||||
text = line
|
||||
gen_list.append(text)
|
||||
gen_list_total.append(gen_list)
|
||||
|
||||
gen_peace = []
|
||||
for index in range(len(gen_list_total[0])):
|
||||
for sample_idx in range(args.num_samples):  # separate name so the outer GPU index `i` is not overwritten
gen_peace.append(gen_list_total[sample_idx][index])
|
||||
|
||||
gen_len = len(gen_peace) // args.num_samples
|
||||
scores, best_gen_list = process_eval(args, gen_peace, tgt[tgt_offset:tgt_offset+gen_len])
|
||||
best_gen_list_total.extend(best_gen_list)
|
||||
for key, values in scores.items():
|
||||
final_scores[key] += values
|
||||
print("scores on gpu ", i)
|
||||
print(scores)
|
||||
tgt_offset += gen_len
|
||||
for key, values in final_scores.items():
|
||||
final_scores[key] = values / args.n_gpu
|
||||
|
||||
print("final score :")
|
||||
print(final_scores)
|
||||
|
||||
assert len(best_gen_list_total) == len(tgt)
|
||||
print("store best gen list ...")
|
||||
out_path = os.path.join(args.generate_path, "best_gen_list.txt")
|
||||
with open(out_path, 'w') as f:
|
||||
for sentence in best_gen_list_total:
|
||||
f.write(sentence + '\n')
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,277 @@
|
|||
import os
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
import string
|
||||
import rouge
|
||||
|
||||
def get_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# out path
|
||||
parser.add_argument('--generate_path', type=str, default='', help='output path')
|
||||
parser.add_argument('--num_samples', type=int, default=50, help='sample query')
|
||||
|
||||
# data args
|
||||
parser.add_argument('--data_path', type=str, default='', help='data path')
|
||||
parser.add_argument('--data_name', type=str, default='', help='data name')
|
||||
parser.add_argument('--batch_size', type=int, default=64, help='')
|
||||
|
||||
# seed
|
||||
parser.add_argument('--seed', type=int, default=101, help='')
|
||||
parser.add_argument('--n_gpu', type=int, default=4, help='')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
_tok_dict = {"(": "-lrb-", ")": "-rrb-",
|
||||
"[": "-lsb-", "]": "-rsb-",
|
||||
"{": "-lcb-", "}": "-rcb-",
|
||||
"[UNK]": "UNK", '&': '&', '<': '<', '>': '>'}
|
||||
|
||||
def _is_digit(w):
|
||||
for ch in w:
|
||||
if not(ch.isdigit() or ch == ','):
|
||||
return False
|
||||
return True
|
||||
|
||||
def fix_tokenization(text):
|
||||
input_tokens = text.split()
|
||||
output_tokens = []
|
||||
has_left_quote = False
|
||||
has_left_single_quote = False
|
||||
|
||||
i = 0
|
||||
prev_dash = False
|
||||
while i < len(input_tokens):
|
||||
tok = input_tokens[i]
|
||||
flag_prev_dash = False
|
||||
if tok in _tok_dict.keys():
|
||||
output_tokens.append(_tok_dict[tok])
|
||||
i += 1
|
||||
elif tok == "\"":
|
||||
if has_left_quote:
|
||||
output_tokens.append("''")
|
||||
else:
|
||||
output_tokens.append("``")
|
||||
has_left_quote = not has_left_quote
|
||||
i += 1
|
||||
elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t":
|
||||
output_tokens[-1] = output_tokens[-1][:-1]
|
||||
output_tokens.append("n't")
|
||||
i += 2
|
||||
elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"):
|
||||
output_tokens.append("'"+input_tokens[i + 1])
|
||||
i += 2
|
||||
elif tok == "'":
|
||||
if has_left_single_quote:
|
||||
output_tokens.append("'")
|
||||
else:
|
||||
output_tokens.append("`")
|
||||
has_left_single_quote = not has_left_single_quote
|
||||
i += 1
|
||||
elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".":
|
||||
output_tokens.append("...")
|
||||
i += 3
|
||||
elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]):
|
||||
# $ 3 , 000 -> $ 3,000
|
||||
output_tokens[-1] += ','+input_tokens[i + 1]
|
||||
i += 2
|
||||
elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit():
|
||||
# 3 . 03 -> 3.03
|
||||
output_tokens[-1] += '.'+input_tokens[i + 1]
|
||||
i += 2
|
||||
elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.':
|
||||
# U . N . -> U.N.
|
||||
k = i+3
|
||||
while k+2 < len(input_tokens):
|
||||
if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.':
|
||||
k += 2
|
||||
else:
|
||||
break
|
||||
output_tokens[-1] += ''.join(input_tokens[i:k])
|
||||
i = k  # skip past the abbreviation that was just merged (e.g. "U . N ." -> "U.N.")
|
||||
elif tok == "-":
|
||||
if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-":
|
||||
output_tokens.append("--")
|
||||
i += 2
|
||||
elif i == len(input_tokens) - 1 or i == 0:
|
||||
output_tokens.append("-")
|
||||
i += 1
|
||||
elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation:
|
||||
output_tokens[-1] += "-"
|
||||
i += 1
|
||||
flag_prev_dash = True
|
||||
else:
|
||||
output_tokens.append("-")
|
||||
i += 1
|
||||
elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation:
|
||||
output_tokens[-1] += tok
|
||||
i += 1
|
||||
else:
|
||||
output_tokens.append(tok)
|
||||
i += 1
|
||||
prev_dash = flag_prev_dash
|
||||
return " ".join(output_tokens)
|
||||
|
||||
def process_eval(args, gen_list, tgt_list, ori_gen_peace):
|
||||
evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2,
|
||||
limit_length=False, apply_avg=True, weight_factor=1.2)
|
||||
|
||||
max_score = []
|
||||
avg_score = []
|
||||
lowest_score = []
|
||||
best_gen_list = []
|
||||
# max score,avg score,lowest score
|
||||
for index, sentence in enumerate(tqdm(gen_list)):
|
||||
if index % args.num_samples == 0:
|
||||
max_score_dict = {'rouge_1':0.0,'rouge_2':0.0,'rouge_l':0.0}
|
||||
avg_score_dict = {'rouge_1': 0.0, 'rouge_2': 0.0, 'rouge_l': 0.0}
|
||||
low_score_dict = {'rouge_1': 1.0, 'rouge_2': 1.0, 'rouge_l': 1.0}
|
||||
target = tgt_list[index // args.num_samples]
|
||||
scores = evaluator.get_scores([sentence], [[target]])
|
||||
rouge_1 = scores['rouge-1']['f']
|
||||
rouge_2 = scores['rouge-2']['f']
|
||||
rouge_l = scores['rouge-l']['f']
|
||||
# max
|
||||
if rouge_2 >= max_score_dict['rouge_2']:
|
||||
if rouge_2 != 0:
|
||||
max_score_dict['rouge_1'] = rouge_1
|
||||
max_score_dict['rouge_2'] = rouge_2
|
||||
max_score_dict['rouge_l'] = rouge_l
|
||||
best_sentence = ori_gen_peace[index]
|
||||
else:
|
||||
if rouge_1 >= max_score_dict['rouge_1']:
|
||||
max_score_dict['rouge_1'] = rouge_1
|
||||
max_score_dict['rouge_2'] = rouge_2
|
||||
max_score_dict['rouge_l'] = rouge_l
|
||||
best_sentence = ori_gen_peace[index]
|
||||
# avg
|
||||
avg_score_dict['rouge_1'] += rouge_1
|
||||
avg_score_dict['rouge_2'] += rouge_2
|
||||
avg_score_dict['rouge_l'] += rouge_l
|
||||
# min
|
||||
if rouge_2 < low_score_dict['rouge_2']:
|
||||
low_score_dict['rouge_1'] = rouge_1
|
||||
low_score_dict['rouge_2'] = rouge_2
|
||||
low_score_dict['rouge_l'] = rouge_l
|
||||
if (index + 1) % args.num_samples == 0:
|
||||
max_score.append(max_score_dict)
|
||||
best_gen_list.append(best_sentence)
|
||||
avg_score_dict['rouge_1'] = avg_score_dict['rouge_1'] / args.num_samples
|
||||
avg_score_dict['rouge_2'] = avg_score_dict['rouge_2'] / args.num_samples
|
||||
avg_score_dict['rouge_l'] = avg_score_dict['rouge_l'] / args.num_samples
|
||||
avg_score.append(avg_score_dict)
|
||||
lowest_score.append(low_score_dict)
|
||||
|
||||
rouge_1 = 0
|
||||
rouge_2 = 0
|
||||
rouge_l = 0
|
||||
for score_dict in max_score:
|
||||
rouge_1 += score_dict['rouge_1']
|
||||
rouge_2 += score_dict['rouge_2']
|
||||
rouge_l += score_dict['rouge_l']
|
||||
max_rouge_1 = rouge_1 / len(max_score)
|
||||
max_rouge_2 = rouge_2 / len(max_score)
|
||||
max_rouge_l = rouge_l / len(max_score)
|
||||
|
||||
rouge_1 = 0
|
||||
rouge_2 = 0
|
||||
rouge_l = 0
|
||||
for score_dict in avg_score:
|
||||
rouge_1 += score_dict['rouge_1']
|
||||
rouge_2 += score_dict['rouge_2']
|
||||
rouge_l += score_dict['rouge_l']
|
||||
avg_rouge_1 = rouge_1 / len(max_score)
|
||||
avg_rouge_2 = rouge_2 / len(max_score)
|
||||
avg_rouge_l = rouge_l / len(max_score)
|
||||
|
||||
rouge_1 = 0
|
||||
rouge_2 = 0
|
||||
rouge_l = 0
|
||||
for score_dict in lowest_score:
|
||||
rouge_1 += score_dict['rouge_1']
|
||||
rouge_2 += score_dict['rouge_2']
|
||||
rouge_l += score_dict['rouge_l']
|
||||
min_rouge_1 = rouge_1 / len(max_score)
|
||||
min_rouge_2 = rouge_2 / len(max_score)
|
||||
min_rouge_l = rouge_l / len(max_score)
|
||||
|
||||
scores = {'max_rouge_1':max_rouge_1, 'max_rouge_2':max_rouge_2, 'max_rouge_l':max_rouge_l,
|
||||
'min_rouge_1':min_rouge_1, 'min_rouge_2':min_rouge_2, 'min_rouge_l':min_rouge_l,
|
||||
'avg_rouge_1':avg_rouge_1, 'avg_rouge_2':avg_rouge_2, 'avg_rouge_l':avg_rouge_l }
|
||||
|
||||
return scores, best_gen_list
|
||||
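# Usage sketch (illustrative values only, not part of the original script):
# with args.num_samples == 3, `gen_list` holds three consecutive candidates per
# target, e.g. [c0_a, c0_b, c0_c, c1_a, c1_b, c1_c, ...], and `tgt_list` holds one
# reference per target. process_eval then reports, per group of three, the best /
# average / worst ROUGE-1/2/L F1 and returns the best raw generation for each target.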
|
||||
|
||||
def main():
|
||||
|
||||
args = get_arguments()
|
||||
tgt = []
|
||||
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
|
||||
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
line = line.strip().replace(" <S_SEP> ", '\n')
|
||||
text = line
|
||||
tgt.append(text)
|
||||
|
||||
final_scores = {'max_rouge_1': 0.0, 'max_rouge_2': 0.0, 'max_rouge_l': 0.0,
|
||||
'min_rouge_1': 0.0, 'min_rouge_2': 0.0, 'min_rouge_l': 0.0,
|
||||
'avg_rouge_1': 0.0, 'avg_rouge_2': 0.0, 'avg_rouge_l': 0.0}
|
||||
tgt_offset = 0
|
||||
best_gen_list_total = []
|
||||
|
||||
for i in range(args.n_gpu):
|
||||
|
||||
gen_list_total = []
|
||||
ori_gen_list_total = []
|
||||
for epoch in range(args.num_samples):
|
||||
gen_list = []
|
||||
ori_gen_list = []
|
||||
gen_text_path = os.path.join(args.generate_path, "rank" + str(i)+"_gen_seed_" + str(args.seed) +
|
||||
"_num" + str(args.num_samples) + "_epoch" + str(epoch+1) + ".txt")
|
||||
with open(gen_text_path, "r", encoding="utf-8") as ifile:
|
||||
for line in tqdm(ifile):
|
||||
ori_gen_list.append(line.strip())
|
||||
line = line.strip().split(' < s _ sep > ')
|
||||
line = [fix_tokenization(sen) for sen in line]
|
||||
text = "\n".join(line)
|
||||
# text = line
|
||||
gen_list.append(text)
|
||||
gen_list_total.append(gen_list)
|
||||
ori_gen_list_total.append(ori_gen_list)
|
||||
|
||||
gen_peace = []
|
||||
for index in range(len(gen_list_total[0])):
|
||||
for j in range(args.num_samples):  # separate index so the outer GPU index i is not shadowed
|
||||
gen_peace.append(gen_list_total[j][index])
|
||||
|
||||
ori_gen_peace = []
|
||||
for index in range(len(ori_gen_list_total[0])):
|
||||
for j in range(args.num_samples):
|
||||
ori_gen_peace.append(ori_gen_list_total[j][index])
|
||||
|
||||
gen_len = len(gen_peace) // args.num_samples
|
||||
scores, best_gen_list = process_eval(args, gen_peace, tgt[tgt_offset:tgt_offset+gen_len], ori_gen_peace)
|
||||
best_gen_list_total.extend(best_gen_list)
|
||||
for key, values in scores.items():
|
||||
final_scores[key] += values
|
||||
print("scores on gpu ", i)
|
||||
print(scores)
|
||||
tgt_offset += gen_len
|
||||
for key, values in final_scores.items():
|
||||
final_scores[key] = values / args.n_gpu
|
||||
|
||||
print("final score :")
|
||||
print(final_scores)
|
||||
|
||||
assert len(best_gen_list_total) == len(tgt)
|
||||
print("store best gen list ...")
|
||||
out_path = os.path.join(args.generate_path, "best_gen_list.txt")
|
||||
with open(out_path, 'w') as f:
|
||||
for sentence in best_gen_list_total:
|
||||
f.write(sentence + '\n')
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,372 @@
|
|||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
) # is self-attn if context is none
|
||||
|
||||
# layer norms
|
||||
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
self.attn1._slice_size = slice_size
|
||||
self.attn2._slice_size = slice_size
|
||||
|
||||
# def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
# if not is_xformers_available():
|
||||
# print("Here is how to install it")
|
||||
# raise ModuleNotFoundError(
|
||||
# "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
# " xformers",
|
||||
# name="xformers",
|
||||
# )
|
||||
# elif not torch.cuda.is_available():
|
||||
# raise ValueError(
|
||||
# "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
||||
# " available for GPU "
|
||||
# )
|
||||
# else:
|
||||
# try:
|
||||
# # Make sure we can run the memory efficient attention
|
||||
# _ = xformers.ops.memory_efficient_attention(
|
||||
# torch.randn((1, 2, 40), device="cuda"),
|
||||
# torch.randn((1, 2, 40), device="cuda"),
|
||||
# torch.randn((1, 2, 40), device="cuda"),
|
||||
# )
|
||||
# except Exception as e:
|
||||
# raise e
|
||||
# self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
# self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
|
||||
'''
|
||||
input value
|
||||
hidden_state: encode from query (batch_size, query_seq_len, hidden_size)
|
||||
context: encode from passage (batch_size, passage_seq_len, hidden_size)
|
||||
|
||||
output value
|
||||
hidden_states : (batch_size, query_seq_len, hidden_size)
|
||||
'''
|
||||
def forward(self, hidden_states, context=None):
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
r"""
|
||||
A cross attention layer.
|
||||
|
||||
Parameters:
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the context. If not given, defaults to `query_dim`.
|
||||
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
bias (`bool`, *optional*, defaults to False):
|
||||
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self._slice_size = None
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
return tensor
|
||||
|
||||
def reshape_batch_dim_to_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
head_size = self.heads
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
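# Shape sketch for the two helpers above (illustrative sizes, not from the file):
#   heads = 8, input of shape (2, 16, 512)
#   reshape_heads_to_batch_dim -> (2 * 8, 16, 512 // 8) = (16, 16, 64)
#   reshape_batch_dim_to_heads -> back to (2, 16, 512)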
|
||||
def forward(self, hidden_states, context=None, mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
context = context if context is not None else hidden_states
|
||||
key = self.to_k(context)
|
||||
value = self.to_v(context)
|
||||
|
||||
dim = query.shape[-1]
|
||||
|
||||
query = self.reshape_heads_to_batch_dim(query)
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
if self._use_memory_efficient_attention_xformers:
|
||||
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
|
||||
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
else:
|
||||
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
||||
hidden_states = self._attention(query, key, value)
|
||||
else:
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _attention(self, query, key, value):
|
||||
# TODO: use baddbmm for better performance
|
||||
if query.device.type == "mps":
|
||||
# Better performance on mps (~20-25%)
|
||||
attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
|
||||
else:
|
||||
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
||||
attention_probs = attention_scores.softmax(dim=-1)
|
||||
# compute attention output
|
||||
|
||||
if query.device.type == "mps":
|
||||
hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
|
||||
else:
|
||||
hidden_states = torch.matmul(attention_probs, value)
|
||||
|
||||
# reshape hidden_states
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
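# The method above is standard scaled dot-product attention,
#   softmax(Q K^T * scale) V   with scale = 1 / sqrt(dim_head),
# computed over tensors already flattened to (batch * heads, seq_len, dim_head).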
|
||||
def _sliced_attention(self, query, key, value, sequence_length, dim):
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
||||
for i in range(hidden_states.shape[0] // slice_size):
|
||||
start_idx = i * slice_size
|
||||
end_idx = (i + 1) * slice_size
|
||||
if query.device.type == "mps":
|
||||
# Better performance on mps (~20-25%)
|
||||
attn_slice = (
|
||||
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
|
||||
* self.scale
|
||||
)
|
||||
else:
|
||||
attn_slice = (
|
||||
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
|
||||
) # TODO: use baddbmm for better performance
|
||||
attn_slice = attn_slice.softmax(dim=-1)
|
||||
if query.device.type == "mps":
|
||||
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
|
||||
else:
|
||||
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
# reshape hidden_states
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
# def _memory_efficient_attention_xformers(self, query, key, value):
|
||||
# query = query.contiguous()
|
||||
# key = key.contiguous()
|
||||
# value = value.contiguous()
|
||||
# hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
|
||||
# hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
# return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input.
|
||||
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
||||
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
if activation_fn == "geglu":
|
||||
geglu = GEGLU(dim, inner_dim)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
geglu = ApproximateGELU(dim, inner_dim)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
self.net.append(geglu)
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
r"""
|
||||
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
||||
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def gelu(self, gate):
|
||||
if gate.device.type != "mps":
|
||||
return F.gelu(gate)
|
||||
# mps: gelu is not implemented for float16
|
||||
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||
return hidden_states * self.gelu(gate)
|
||||
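# Note: the projection above produces 2 * dim_out channels; the first half is the
# value and the second half the gate, giving the GEGLU form value * GELU(gate)
# from the paper linked in the docstring.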
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
"""
|
||||
The approximate form of Gaussian Error Linear Unit (GELU)
|
||||
|
||||
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, num_embeddings):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, timestep):
|
||||
emb = self.linear(self.silu(self.emb(timestep)))
|
||||
scale, shift = torch.chunk(emb, 2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
|
@ -0,0 +1,372 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
import math
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertConfig,
|
||||
)
|
||||
from model.CrossAttentionTransformers import BasicTransformerBlock
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
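# Example (hedged sketch, not part of the original file):
#   t = torch.tensor([0, 10, 100])        # three diffusion steps
#   emb = timestep_embedding(t, dim=128)  # -> (3, 128): cosine half, then sine half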
|
||||
class Diffusion_LM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
dropout=0,
|
||||
config=None,
|
||||
config_name='bert-base-uncased',
|
||||
vocab_size=None,
|
||||
init_pretrained=True,
|
||||
logits_mode=1,
|
||||
token_emb_type='pretrain',
|
||||
# num_heads=1,
|
||||
# channel_mult=(1, 2, 4, 8),
|
||||
# use_scale_shift_norm=False,
|
||||
# training_mode='emb',
|
||||
# experiment_mode='lm',
|
||||
# num_heads_upsample=-1,
|
||||
# use_checkpoint=False,
|
||||
# num_classes=None,
|
||||
# conv_resample=True,
|
||||
# attention_resolutions,
|
||||
# num_res_blocks,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.dropout = dropout
|
||||
self.logits_mode = logits_mode
|
||||
self.init_pretrained = init_pretrained
|
||||
self.token_emb_type = token_emb_type
|
||||
|
||||
config = BertConfig.from_pretrained(config_name)
|
||||
config.hidden_dropout_prob = self.dropout
|
||||
print(config)
|
||||
|
||||
# trainable embedding layer
|
||||
self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
|
||||
# position embedding
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
if self.token_emb_type == 'pretrain':
|
||||
temp_bert = BertModel.from_pretrained(config_name, config=config)
|
||||
self.word_embedding.weight = temp_bert.embeddings.word_embeddings.weight
|
||||
self.position_embeddings.weight = temp_bert.embeddings.position_embeddings.weight
|
||||
elif self.token_emb_type == 'random':
|
||||
print("load embedding weight random")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.logits_mode == 2:
|
||||
# self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=False)
|
||||
self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=True)
|
||||
else:
|
||||
self.lm_head = nn.Linear(self.in_channels, vocab_size)
|
||||
|
||||
# share weight between lm_head and word_embedding
|
||||
with torch.no_grad():
|
||||
self.lm_head.weight = self.word_embedding.weight
|
||||
|
||||
# self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
|
||||
# self.lm_head = nn.Linear(self.in_channels, vocab_size)
|
||||
# with th.no_grad():
|
||||
# self.lm_head.weight = self.word_embedding.weight
|
||||
|
||||
# time embedding
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_dim, config.hidden_size),
|
||||
)
|
||||
|
||||
# position embedding
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# # label embedding
|
||||
# if self.num_classes is not None:
|
||||
# self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
|
||||
# input transform
|
||||
self.input_up_proj = nn.Sequential(
|
||||
nn.Linear(in_channels, config.hidden_size),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.hidden_size, config.hidden_size)
|
||||
)
|
||||
|
||||
# Dropout
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
'''
|
||||
Diffusion Transformer (6 layer)
|
||||
'''
|
||||
if self.init_pretrained:
|
||||
temp_bert = BertModel.from_pretrained(config_name, config=config)
|
||||
del temp_bert.embeddings
|
||||
del temp_bert.pooler
|
||||
self.input_transformers = temp_bert.encoder
|
||||
print('initializing from pretrained bert.')
|
||||
else:
|
||||
temp_bert = BertModel(config)
|
||||
self.input_transformers = temp_bert.encoder
|
||||
print('initializing from random bert.')
|
||||
|
||||
# output transform
|
||||
self.output_down_proj = nn.Sequential(
|
||||
nn.Linear(config.hidden_size, config.hidden_size),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.hidden_size, out_channels)
|
||||
)
|
||||
|
||||
def get_embeds(self, input_ids):
|
||||
return self.word_embedding(input_ids)
|
||||
|
||||
def get_logits(self, hidden_repr):
|
||||
if self.logits_mode == 1:
|
||||
return self.lm_head(hidden_repr)
|
||||
elif self.logits_mode == 2:
|
||||
text_emb = hidden_repr
|
||||
emb_norm = (self.lm_head.weight ** 2).sum(-1).view(-1, 1) # vocab
|
||||
text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
|
||||
arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
|
||||
dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(self.lm_head.weight,
|
||||
text_emb_t) # (vocab, d) x (d, bsz*seqlen)
|
||||
scores = torch.sqrt(torch.clamp(dist, 0.0, np.inf)).view(emb_norm.size(0), hidden_repr.size(0),
|
||||
hidden_repr.size(1)) # vocab, bsz*seqlen
|
||||
scores = -scores.permute(1, 2, 0).contiguous()
|
||||
|
||||
#
|
||||
# scores1 = th.cdist(self.lm_head.weight.unsqueeze(0), hidden_repr, p=2)
|
||||
# scores1 = -scores1.permute(0, 2, 1).contiguous()
|
||||
#
|
||||
# print(scores1.shape, scores.shape)
|
||||
# print(scores1[0,0], scores[0,0])
|
||||
# print(torch.isclose(scores1, scores))
|
||||
|
||||
return scores
|
||||
else:
|
||||
raise NotImplementedError
|
||||
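# Note on logits_mode == 2 above: `scores` holds negative Euclidean distances
# between each hidden vector h and every tied embedding row e_v,
#   score(v) = -sqrt(||e_v||^2 + ||h||^2 - 2 * e_v . h),
# so an argmax over the vocabulary picks the nearest-neighbour token embedding.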
|
||||
def forward(self, x, timesteps, attention_mask=None, y=None, src_ids=None, src_mask=None):
|
||||
|
||||
# prepare embedding
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb_x = self.input_up_proj(x)
|
||||
seq_length = x.size(1)
|
||||
position_ids = self.position_ids[:, : seq_length]
|
||||
# print(emb_x.shape, emb.shape, self.position_embeddings)
|
||||
emb_inputs = self.position_embeddings(position_ids) + emb_x + emb.unsqueeze(1).expand(-1, seq_length, -1)
|
||||
emb_inputs = self.dropout(self.LayerNorm(emb_inputs))
|
||||
|
||||
# encode embedding
|
||||
# print(emb_inputs.shape, attention_mask.shape)
|
||||
input_trans_hidden_states = self.input_transformers(emb_inputs, attention_mask=attention_mask).last_hidden_state
|
||||
h = self.output_down_proj(input_trans_hidden_states)
|
||||
h = h.type(x.dtype)
|
||||
return h
|
||||
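# Forward sketch: x is the noised embedding sequence of shape (B, L, in_channels).
# The timestep embedding is broadcast over L, summed with the up-projected input
# and the position embeddings, passed through the BERT encoder, and down-projected
# to (B, L, out_channels), i.e. back into the word-embedding space.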
|
||||
|
||||
|
||||
class CrossAttention_Diffusion_LM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
dropout=0,
|
||||
config=None,
|
||||
config_name='bert-base-uncased',
|
||||
vocab_size=None,
|
||||
init_pretrained=True,
|
||||
logits_mode=1,
|
||||
token_emb_type='pretrain',
|
||||
fix_encoder=False
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.dropout = dropout
|
||||
self.logits_mode = logits_mode
|
||||
self.init_pretrained = init_pretrained
|
||||
self.token_emb_type = token_emb_type
|
||||
self.fix_encoder = fix_encoder
|
||||
|
||||
cfg = BertConfig.from_pretrained(config_name)
|
||||
cfg.num_hidden_layers = 6
|
||||
self.passage_encoder = BertModel.from_pretrained(config_name, config=cfg)
|
||||
# self.passage_encoder = BertModel.from_pretrained(
|
||||
# "/colab_space/Lin0/PROD/KDexp/pretrain_model/bert-base-uncased", config=cfg)
|
||||
|
||||
|
||||
|
||||
|
||||
config = BertConfig.from_pretrained(config_name)
|
||||
config.hidden_dropout_prob = self.dropout
|
||||
print(config)
|
||||
|
||||
# trainable embedding layer
|
||||
self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
|
||||
# position embedding
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
if self.logits_mode == 2:
|
||||
# self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=False)
|
||||
self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=True)
|
||||
else:
|
||||
self.lm_head = nn.Linear(self.in_channels, vocab_size)
|
||||
|
||||
# share weight between lm_head and word_embedding
|
||||
with torch.no_grad():
|
||||
self.lm_head.weight = self.word_embedding.weight
|
||||
|
||||
# self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
|
||||
# self.lm_head = nn.Linear(self.in_channels, vocab_size)
|
||||
# with th.no_grad():
|
||||
# self.lm_head.weight = self.word_embedding.weight
|
||||
|
||||
# time embedding layer
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(time_embed_dim, config.hidden_size),
|
||||
)
|
||||
|
||||
# position embedding
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# # label embedding
|
||||
# if self.num_classes is not None:
|
||||
# self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
|
||||
# input transform
|
||||
self.input_up_proj = nn.Sequential(
|
||||
nn.Linear(in_channels, config.hidden_size),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.hidden_size, config.hidden_size)
|
||||
)
|
||||
|
||||
# Dropout
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
config.num_hidden_layers = 6
|
||||
# define cross attention transformer block(6 layer)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
attention_head_dim=config.hidden_size // config.num_attention_heads,
|
||||
dropout=config.hidden_dropout_prob,
|
||||
cross_attention_dim=config.hidden_size,
|
||||
activation_fn="geglu",
|
||||
)
|
||||
for d in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# output transform
|
||||
self.output_down_proj = nn.Sequential(
|
||||
nn.Linear(config.hidden_size, config.hidden_size),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.hidden_size, out_channels)
|
||||
)
|
||||
|
||||
def get_embeds(self, input_ids):
|
||||
return self.word_embedding(input_ids)
|
||||
|
||||
def get_logits(self, hidden_repr):
|
||||
if self.logits_mode == 1:
|
||||
return self.lm_head(hidden_repr)
|
||||
elif self.logits_mode == 2:
|
||||
text_emb = hidden_repr
|
||||
emb_norm = (self.lm_head.weight ** 2).sum(-1).view(-1, 1) # vocab
|
||||
text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
|
||||
arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
|
||||
dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(self.lm_head.weight,
|
||||
text_emb_t) # (vocab, d) x (d, bsz*seqlen)
|
||||
scores = torch.sqrt(torch.clamp(dist, 0.0, np.inf)).view(emb_norm.size(0), hidden_repr.size(0),
|
||||
hidden_repr.size(1)) # vocab, bsz*seqlen
|
||||
scores = -scores.permute(1, 2, 0).contiguous()
|
||||
|
||||
return scores
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, x, timesteps, src_input_ids, src_attention_mask, attention_mask=None,
|
||||
answer_id=None, answer_mask=None, y=None, src_ids=None, src_mask=None):
|
||||
|
||||
# prepare embedding
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
emb_x = self.input_up_proj(x)
|
||||
seq_length = x.size(1)
|
||||
position_ids = self.position_ids[:, : seq_length]
|
||||
# print(emb_x.shape, emb.shape, self.position_embeddings)
|
||||
emb_inputs = self.position_embeddings(position_ids) + emb_x + emb.unsqueeze(1).expand(-1, seq_length, -1)
|
||||
hidden_states = self.dropout(self.LayerNorm(emb_inputs))
|
||||
# encode embedding
|
||||
# print(emb_inputs.shape, attention_mask.shape)
|
||||
if self.fix_encoder:
|
||||
with torch.no_grad():
|
||||
out = self.passage_encoder(input_ids=src_input_ids,
|
||||
attention_mask=src_attention_mask)
|
||||
passage_hidden = out.last_hidden_state
|
||||
else:
|
||||
out = self.passage_encoder(input_ids=src_input_ids,
|
||||
attention_mask=src_attention_mask)
|
||||
passage_hidden = out.last_hidden_state + 0 * out.pooler_output.unsqueeze(1)  # the zero-weighted pooler term presumably keeps the pooler parameters in the DDP/autograd graph
|
||||
|
||||
if answer_id is not None:
|
||||
answer_hidden_states = hidden_states.clone()
|
||||
answer_out = self.passage_encoder(input_ids=answer_id,
|
||||
attention_mask=answer_mask)
|
||||
answer_hidden = answer_out.last_hidden_state + 0 * answer_out.pooler_output.unsqueeze(1)
|
||||
for block in self.transformer_blocks:
|
||||
answer_hidden_states = block(answer_hidden_states, answer_hidden)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, passage_hidden)
|
||||
|
||||
if answer_id is not None:
|
||||
# print("model_qg_forward...")
|
||||
hidden_states = hidden_states + answer_hidden_states
|
||||
|
||||
h = self.output_down_proj(hidden_states)
|
||||
h = h.type(x.dtype)
|
||||
return h
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""
|
||||
Helpers for distributed training.
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import socket
|
||||
|
||||
import blobfile as bf
|
||||
from mpi4py import MPI
|
||||
import torch as th
|
||||
import torch.distributed as dist
|
||||
|
||||
# Change this to reflect your cluster layout.
|
||||
# The GPU for a given rank is (rank % GPUS_PER_NODE).
|
||||
GPUS_PER_NODE = 1 #8
|
||||
|
||||
SETUP_RETRY_COUNT = 3
|
||||
|
||||
|
||||
def setup_dist():
|
||||
"""
|
||||
Setup a distributed process group.
|
||||
"""
|
||||
if dist.is_initialized():
|
||||
return
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
backend = "gloo" if not th.cuda.is_available() else "nccl"
|
||||
|
||||
if backend == "gloo":
|
||||
hostname = "localhost"
|
||||
else:
|
||||
hostname = socket.gethostbyname(socket.getfqdn())
|
||||
os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
|
||||
os.environ["RANK"] = str(comm.rank)
|
||||
os.environ["WORLD_SIZE"] = str(comm.size)
|
||||
|
||||
port = comm.bcast(_find_free_port(), root=0)
|
||||
os.environ["MASTER_PORT"] = str(port)
|
||||
dist.init_process_group(backend=backend, init_method="env://")
|
||||
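# Launch sketch (assumed command line, adapt to your cluster):
#   mpiexec -n 4 python train.py ...
# Rank 0 broadcasts MASTER_ADDR and MASTER_PORT; every MPI rank exports its RANK
# and WORLD_SIZE and joins the nccl (GPU) or gloo (CPU) process group above.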
|
||||
|
||||
def dev():
|
||||
"""
|
||||
Get the device to use for torch.distributed.
|
||||
"""
|
||||
if th.cuda.is_available():
|
||||
return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
|
||||
return th.device("cpu")
|
||||
|
||||
|
||||
def load_state_dict(path, **kwargs):
|
||||
"""
|
||||
Load a PyTorch file without redundant fetches across MPI ranks.
|
||||
"""
|
||||
if MPI.COMM_WORLD.Get_rank() == 0:
|
||||
with bf.BlobFile(path, "rb") as f:
|
||||
data = f.read()
|
||||
else:
|
||||
data = None
|
||||
data = MPI.COMM_WORLD.bcast(data)
|
||||
return th.load(io.BytesIO(data), **kwargs)
|
||||
|
||||
|
||||
def sync_params(params):
|
||||
"""
|
||||
Synchronize a sequence of Tensors across ranks from rank 0.
|
||||
"""
|
||||
for p in params:
|
||||
with th.no_grad():
|
||||
dist.broadcast(p, 0)
|
||||
|
||||
|
||||
def _find_free_port():
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.bind(("", 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
return s.getsockname()[1]
|
||||
finally:
|
||||
s.close()
|
|
@ -0,0 +1,375 @@
|
|||
import torch
|
||||
import copy
|
||||
import os
|
||||
from torch import nn
|
||||
import collections
|
||||
from util import logger
|
||||
from train_util import dist_util
|
||||
from transformers import AdamW
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.serialization import default_restore_location
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from transformers import (
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from data_util.pretrain_data_util import load_loop_pretrain_data
|
||||
|
||||
# from text_data_util import Text_Hidden_dataset, Question_dataset, PandQ_dataset
|
||||
# from s2s_data_util import S2S_dataset
|
||||
from data_util.pretrain_data_util import Pre_dataset, Pre_dataset_type2
|
||||
from diffusion_util.resample import LossAwareSampler, UniformSampler
|
||||
|
||||
INITIAL_LOG_LOSS_SCALE = 20.0
|
||||
CheckpointState = collections.namedtuple("CheckpointState",
|
||||
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
|
||||
|
||||
'''
|
||||
TrainLoop
|
||||
'''
|
||||
class PretrainLoop:
|
||||
def __init__(
|
||||
self,
|
||||
train_type,
|
||||
model,
|
||||
diffusion,
|
||||
data,
|
||||
batch_size,
|
||||
lr,
|
||||
ema_rate,
|
||||
log_interval,
|
||||
save_interval,
|
||||
resume_checkpoint,
|
||||
warmup_steps=0,
|
||||
use_fp16=False,
|
||||
fp16_scale_growth=1e-3,
|
||||
schedule_sampler=None,
|
||||
weight_decay=0.0,
|
||||
lr_anneal_steps=0,
|
||||
checkpoint_path='',
|
||||
gradient_clipping=-1.,
|
||||
eval_data=None,
|
||||
eval_interval=-1,
|
||||
gradient_accumulation_steps=1,
|
||||
device=None,
|
||||
args=None,
|
||||
tokenizer=None
|
||||
):
|
||||
self.train_type = train_type
|
||||
self.model = model
|
||||
self.diffusion = diffusion
|
||||
self.data = data
|
||||
self.eval_data = eval_data
|
||||
self.batch_size = batch_size
|
||||
self.lr = lr
|
||||
self.ema_rate = (
|
||||
[ema_rate]
|
||||
if isinstance(ema_rate, float)
|
||||
else [float(x) for x in ema_rate.split(",")]
|
||||
)
|
||||
self.log_interval = log_interval
|
||||
self.eval_interval = eval_interval
|
||||
self.save_interval = save_interval
|
||||
self.resume_checkpoint = resume_checkpoint
|
||||
self.warmup_steps = warmup_steps
|
||||
self.use_fp16 = use_fp16
|
||||
self.fp16_scale_growth = fp16_scale_growth
|
||||
self.schedule_sampler = schedule_sampler
|
||||
self.weight_decay = weight_decay
|
||||
self.lr_anneal_steps = lr_anneal_steps
|
||||
self.gradient_clipping = gradient_clipping
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.device = device
|
||||
self.args = args
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# self.global_batch = self.batch_size * dist.get_world_size()
|
||||
|
||||
self.master_params = list(self.model.parameters())
|
||||
self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
|
||||
# self.sync_cuda = th.cuda.is_available()
|
||||
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.optimizer = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
|
||||
|
||||
self.total_step = 5000000
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.total_step
|
||||
)
|
||||
self.global_step = 0
|
||||
|
||||
if self.checkpoint_path is not None:
|
||||
model_checkpoint_files = []
|
||||
ema_checkpoint_files = []
|
||||
if os.path.exists(self.checkpoint_path):
|
||||
for item in os.scandir(self.checkpoint_path):
|
||||
if item.is_file():
|
||||
if "model_checkpoint" in item.path:
|
||||
model_checkpoint_files.append(item.path)
|
||||
if "ema" in item.path:
|
||||
ema_checkpoint_files.append(item.path)
|
||||
if len(model_checkpoint_files) != 0 and len(ema_checkpoint_files) != 0:
|
||||
model_checkpoint_files.sort(key=lambda f: int(f.split('model_checkpoint-')[1]), reverse=True)
|
||||
logger.info("***** load " + model_checkpoint_files[0] + " *****")
|
||||
ema_checkpoint_files.sort(key=lambda f: int(f.split('checkpoint-')[-1]), reverse=True)
|
||||
logger.info("***** load " + ema_checkpoint_files[0] + " *****")
|
||||
model_saved_state = load_states_from_checkpoint(model_checkpoint_files[0])
|
||||
self.global_step = self._load_saved_state(model_saved_state)
|
||||
self.ema_params = [
|
||||
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
|
||||
]
|
||||
else:
|
||||
self.ema_params = [
|
||||
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
|
||||
]
|
||||
logger.info("***** there are no checkpoint in" + self.checkpoint_path + " *****")
|
||||
|
||||
else:
|
||||
self.ema_params = [
|
||||
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
|
||||
]
|
||||
|
||||
# model to DDP
|
||||
if dist.get_world_size() > 1:
|
||||
self.model = DDP(
|
||||
self.model, device_ids=[dist.get_rank()], output_device=dist.get_rank(), find_unused_parameters=False,
|
||||
)
|
||||
else:
|
||||
print("single GPU is not achieve now")
|
||||
exit(0)
|
||||
|
||||
def run_loop(self):
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Max steps = %d", self.total_step)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", self.batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
self.batch_size
|
||||
* self.gradient_accumulation_steps
|
||||
* (dist.get_world_size()),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", self.gradient_accumulation_steps)
|
||||
self.model.zero_grad()
|
||||
self.model.train()
|
||||
|
||||
while self.global_step < self.total_step:
|
||||
|
||||
for data_name in self.data:
|
||||
|
||||
print("pretraining diffusion using data :", data_name)
|
||||
|
||||
train_data = load_loop_pretrain_data(
|
||||
self.args,
|
||||
padding_mode='conti_tgt',
|
||||
tokenizer=self.tokenizer,
|
||||
data_name=data_name,
|
||||
)
|
||||
|
||||
# ddp data sample
|
||||
train_sample = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sample, batch_size=self.batch_size, drop_last=False,
|
||||
num_workers=20, collate_fn=Pre_dataset_type2.get_collate_fn())
|
||||
|
||||
# while self.global_step < self.lr_anneal_steps:
|
||||
|
||||
# training for one epoch
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=dist.get_rank() not in [-1, 0])
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
self.model.train()
|
||||
|
||||
# forward loss
|
||||
self.forward_backward(batch)
|
||||
if self.use_fp16:
|
||||
pass
|
||||
else:
|
||||
if (step + 1) % self.gradient_accumulation_steps == 0:
|
||||
# gradient clip
|
||||
if self.gradient_clipping > 0:
|
||||
self.grad_clip()
|
||||
self._log_grad_norm()
|
||||
self.optimizer.step()
|
||||
# lr scheduler
|
||||
# self._anneal_lr()
|
||||
self.scheduler.step()
|
||||
self.model.zero_grad()
|
||||
for rate, params in zip(self.ema_rate, self.ema_params):
|
||||
self.update_ema(params, self.master_params, rate=rate)
|
||||
self.global_step += 1
|
||||
self.log_step()
|
||||
|
||||
if (self.global_step + 1) % self.log_interval == 0:
|
||||
logger.dumpkvs()
|
||||
# dist.barrier()
|
||||
if (self.global_step + 1) % self.save_interval == 0:
|
||||
self.save()
|
||||
|
||||
# self.save()
|
||||
if self.global_step > self.total_step:
|
||||
break
|
||||
|
||||
def save(self):
|
||||
|
||||
def save_checkpoint(rate, ema_params):
|
||||
model_to_save = get_model_obj(self.model)
|
||||
if not rate:
|
||||
model_state_dict = model_to_save.state_dict()
|
||||
else:
|
||||
model_state_dict = model_to_save.state_dict()
|
||||
for i, (name, _value) in enumerate(model_to_save.named_parameters()):
|
||||
assert name in model_state_dict
|
||||
model_state_dict[name] = ema_params[i]
|
||||
|
||||
opt_state_dict = self.optimizer.state_dict()
|
||||
sch_state_dict = self.scheduler.state_dict()
|
||||
offset = self.global_step
|
||||
state = CheckpointState(model_state_dict,
|
||||
opt_state_dict,
|
||||
sch_state_dict,
|
||||
offset,
|
||||
)
|
||||
if not rate:
|
||||
ckpt_path = os.path.join(self.checkpoint_path, 'model_checkpoint-' + str(offset))
|
||||
else:
|
||||
ckpt_path = os.path.join(self.checkpoint_path, 'ema_' + str(rate) + '_checkpoint-' + str(offset))
|
||||
|
||||
torch.save(state._asdict(), ckpt_path)
|
||||
logger.info('Saved checkpoint at %s', ckpt_path)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
save_checkpoint(0, None)
|
||||
for rate, params in zip(self.ema_rate, self.ema_params):
|
||||
save_checkpoint(rate, params)
|
||||
|
||||
# dist.barrier()
|
||||
|
||||
|
||||
def forward_backward(self, batch):
|
||||
|
||||
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
|
||||
# print("src_input_ids shape:", batch['src_input_ids'].shape)
|
||||
# print("tgt_input_ids shape:", batch['tgt_input_ids'].shape)
|
||||
losses = self.diffusion.training_losses(self.model, batch, t)
|
||||
|
||||
if isinstance(self.schedule_sampler, LossAwareSampler):
|
||||
self.schedule_sampler.update_with_local_losses(
|
||||
t, losses["loss"].detach()
|
||||
)
|
||||
|
||||
loss = (losses["loss"] * weights).mean()
|
||||
if self.gradient_accumulation_steps > 1:
|
||||
loss = loss / self.gradient_accumulation_steps
|
||||
log_loss_dict(
|
||||
self.diffusion, t, {k: v * weights for k, v in losses.items()}
|
||||
)
|
||||
|
||||
if self.use_fp16:
|
||||
loss_scale = 2 ** self.lg_loss_scale
|
||||
(loss * loss_scale).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
def forward_only(self, batch):
|
||||
with torch.no_grad():
|
||||
self.model.zero_grad()
|
||||
if self.train_type == 'LM_Diffusion':
|
||||
t, weights = self.schedule_sampler.sample(batch[1].shape[0], dist_util.dev())
|
||||
inputs_text = {"query_ids": batch[1].long().to(self.device),
|
||||
"attention_mask_q": batch[2].long().to(self.device)}
|
||||
|
||||
losses = self.diffusion.training_losses(self.model, inputs_text, t)
|
||||
elif self.train_type == 'S2S_Diffusion':
|
||||
'''
|
||||
for s2s
|
||||
'''
|
||||
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
|
||||
losses = self.diffusion.training_losses(self.model, batch, t)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
log_loss_dict(
|
||||
self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()}
|
||||
)
|
||||
|
||||
def _log_grad_norm(self):
|
||||
sqsum = 0.0
|
||||
for p in self.master_params:
|
||||
# print(p)
|
||||
sqsum += (p.grad ** 2).sum().item()
|
||||
logger.logkv_mean("grad_norm", np.sqrt(sqsum))
|
||||
|
||||
def log_step(self):
|
||||
logger.logkv("step", self.global_step)
|
||||
if self.use_fp16:
|
||||
logger.logkv("lg_loss_scale", self.lg_loss_scale)
|
||||
|
||||
def _anneal_lr(self):
|
||||
if not self.lr_anneal_steps:
|
||||
return
|
||||
frac_done = self.global_step / self.lr_anneal_steps
|
||||
lr = self.lr * (1 - frac_done)
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
def grad_clip(self):
|
||||
# print('doing gradient clipping')
|
||||
max_grad_norm = self.gradient_clipping  # e.g. 3.0
|
||||
if hasattr(self.optimizer, "clip_grad_norm"):
|
||||
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
|
||||
self.optimizer.clip_grad_norm(max_grad_norm)
|
||||
else:
|
||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), #amp.master_params(self.opt) if self.use_apex else
|
||||
max_grad_norm,
|
||||
)
|
||||
|
||||
def _load_saved_state(self, saved_state: CheckpointState):
|
||||
self.global_step = saved_state.offset
|
||||
logger.info('Loading checkpoint @ step=%s', self.global_step)
|
||||
|
||||
logger.info('Loading saved model state ...')
|
||||
self.model.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection
|
||||
self.optimizer.load_state_dict(saved_state.optimizer_dict)
|
||||
self.scheduler.load_state_dict(saved_state.scheduler_dict)
|
||||
self.master_params = list(self.model.parameters())
|
||||
return self.global_step
|
||||
|
||||
|
||||
def update_ema(self, target_params, source_params, rate=0.99):
|
||||
"""
|
||||
Update target parameters to be closer to those of source parameters using
|
||||
an exponential moving average.
|
||||
|
||||
:param target_params: the target parameter sequence.
|
||||
:param source_params: the source parameter sequence.
|
||||
:param rate: the EMA rate (closer to 1 means slower).
|
||||
"""
|
||||
for targ, src in zip(target_params, source_params):
|
||||
# print("target_params:", targ.device)
|
||||
# print("source_params:", src.device)
|
||||
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
||||
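# The in-place update above is equivalent to
#   targ <- rate * targ + (1 - rate) * src,
# e.g. with rate = 0.9999 the EMA copy moves 0.01% toward the live weights per step.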
|
||||
|
||||
|
||||
def log_loss_dict(diffusion, ts, losses):
|
||||
for key, values in losses.items():
|
||||
logger.logkv_mean(key, values.mean().item())
|
||||
# Log the quantiles (four quartiles, in particular).
|
||||
for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
|
||||
quartile = int(4 * sub_t / diffusion.num_timesteps)
|
||||
logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
|
||||
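# `quartile` buckets the sampled timestep into 0..3 (e.g. t = 1500 of 2000 steps
# falls in q3), so every loss term is also logged per quarter of the schedule.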
|
||||
def get_model_obj(model: nn.Module):
|
||||
return model.module if hasattr(model, 'module') else model
|
||||
|
||||
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
|
||||
logger.info('Reading saved model from %s', model_file)
|
||||
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
|
||||
logger.info('model_state_dict keys %s', state_dict.keys())
|
||||
return CheckpointState(**state_dict)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,396 @@
|
|||
import torch
|
||||
import copy
|
||||
import os
|
||||
from torch import nn
|
||||
import collections
|
||||
from util import logger
|
||||
from train_util import dist_util
|
||||
from transformers import AdamW
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from torch.serialization import default_restore_location
|
||||
from transformers import (
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertConfig,
|
||||
)
|
||||
from diffusion_util.resample import LossAwareSampler, UniformSampler
|
||||
|
||||
from data_util.text_data_util import Text_Hidden_dataset, Question_dataset, PandQ_dataset
|
||||
from data_util.s2s_data_util import S2S_dataset, QG_dataset_Diff
|
||||
|
||||
INITIAL_LOG_LOSS_SCALE = 20.0
|
||||
CheckpointState = collections.namedtuple("CheckpointState",
|
||||
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
|
||||
|
||||
'''
|
||||
TrainLoop training class
|
||||
'''
|
||||
class TrainLoop:
|
||||
def __init__(
|
||||
self,
|
||||
train_type,
|
||||
model,
|
||||
diffusion,
|
||||
data,
|
||||
batch_size,
|
||||
lr,
|
||||
ema_rate,
|
||||
log_interval,
|
||||
save_interval,
|
||||
resume_checkpoint,
|
||||
warmup_steps=0,
|
||||
use_fp16=False,
|
||||
fp16_scale_growth=1e-3,
|
||||
schedule_sampler=None,
|
||||
weight_decay=0.0,
|
||||
lr_anneal_steps=0,
|
||||
checkpoint_path='',
|
||||
gradient_clipping=-1.,
|
||||
eval_data=None,
|
||||
eval_interval=-1,
|
||||
gradient_accumulation_steps=1,
|
||||
device=None,
|
||||
data_name="xsum_data",
|
||||
):
|
||||
self.train_type = train_type
|
||||
self.model = model
|
||||
self.diffusion = diffusion
|
||||
self.data = data
|
||||
self.eval_data = eval_data
|
||||
self.batch_size = batch_size
|
||||
self.lr = lr
|
||||
self.ema_rate = (
|
||||
[ema_rate]
|
||||
if isinstance(ema_rate, float)
|
||||
else [float(x) for x in ema_rate.split(",")]
|
||||
)
|
||||
self.log_interval = log_interval
|
||||
self.eval_interval = eval_interval
|
||||
self.save_interval = save_interval
|
||||
self.resume_checkpoint = resume_checkpoint
|
||||
self.warmup_steps = warmup_steps
|
||||
self.use_fp16 = use_fp16
|
||||
self.fp16_scale_growth = fp16_scale_growth
|
||||
self.schedule_sampler = schedule_sampler
|
||||
self.weight_decay = weight_decay
|
||||
self.lr_anneal_steps = lr_anneal_steps
|
||||
self.gradient_clipping = gradient_clipping
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.device = device
|
||||
self.data_name = data_name
|
||||
self.master_params = list(self.model.parameters())
|
||||
self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
|
||||
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.optimizer = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.lr_anneal_steps
|
||||
)
|
||||
self.global_step = 0
|
||||
|
||||
# load last checkpoint
|
||||
if self.checkpoint_path is not None:
|
||||
model_checkpoint_files = []
|
||||
ema_checkpoint_files = []
|
||||
if os.path.exists(self.checkpoint_path):
|
||||
for item in os.scandir(self.checkpoint_path):
|
||||
if item.is_file():
|
||||
if "model_checkpoint" in item.path:
|
||||
model_checkpoint_files.append(item.path)
|
||||
if "ema" in item.path:
|
||||
ema_checkpoint_files.append(item.path)
|
||||
if len(model_checkpoint_files) != 0 and len(ema_checkpoint_files) != 0:
|
||||
model_checkpoint_files.sort(key=lambda f: int(f.split('model_checkpoint-')[1]), reverse=True)
|
||||
logger.info("***** load " + model_checkpoint_files[0] + " *****")
|
||||
ema_checkpoint_files.sort(key=lambda f: int(f.split('checkpoint-')[-1]), reverse=True)
|
||||
logger.info("***** load " + ema_checkpoint_files[0] + " *****")
|
||||
model_saved_state = load_states_from_checkpoint(model_checkpoint_files[0])
|
||||
self.global_step = self._load_saved_state(model_saved_state)
|
||||
self.ema_params = [
|
||||
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
|
||||
]
|
||||
else:
|
||||
self.ema_params = [
|
||||
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
|
||||
]
|
||||
logger.info("***** there are no checkpoint in" + self.checkpoint_path + " *****")
|
||||
|
||||
else:
|
||||
self.ema_params = [
|
||||
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
|
||||
]
|
||||
|
||||
# model to DDP
|
||||
if dist.get_world_size() > 1:
|
||||
self.model = DDP(
|
||||
self.model, device_ids=[dist.get_rank()], output_device=dist.get_rank(), find_unused_parameters=False,
|
||||
)
|
||||
else:
|
||||
print("single GPU is not achieve now")
|
||||
exit(0)
|
||||
|
||||
def run_loop(self):
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Max steps = %d", self.lr_anneal_steps)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", self.batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
self.batch_size
|
||||
* self.gradient_accumulation_steps
|
||||
* (dist.get_world_size()),
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = %d", self.gradient_accumulation_steps)
|
||||
self.model.zero_grad()
|
||||
self.model.train()
|
||||
|
||||
# ddp data sample
|
||||
if self.train_type == 'LM_Diffusion':
|
||||
train_sample = DistributedSampler(self.data)
|
||||
train_dataloader = DataLoader(self.data, sampler=train_sample, batch_size=self.batch_size, drop_last=False,
|
||||
num_workers=20, collate_fn=Question_dataset.get_collate_fn())
|
||||
elif self.train_type == 'S2S_Diffusion':
|
||||
train_sample = DistributedSampler(self.data)
|
||||
'''
|
||||
for s2s
|
||||
'''
|
||||
train_dataloader = DataLoader(self.data, sampler=train_sample, batch_size=self.batch_size, drop_last=False,
|
||||
num_workers=20, collate_fn=S2S_dataset.get_collate_fn())
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
while self.global_step < self.lr_anneal_steps:
|
||||
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=dist.get_rank() not in [-1, 0])
|
||||
|
||||
for batch in epoch_iterator:
|
||||
self.model.train()
|
||||
|
||||
# forward loss
|
||||
self.forward_backward(batch)
|
||||
if self.use_fp16:
|
||||
pass
|
||||
else:
|
||||
# gradient clip
|
||||
if self.gradient_clipping > 0:
|
||||
self.grad_clip()
|
||||
self._log_grad_norm()
|
||||
self.optimizer.step()
|
||||
# lr scheduler
|
||||
self.scheduler.step()
|
||||
self.model.zero_grad()
|
||||
# ema
|
||||
for rate, params in zip(self.ema_rate, self.ema_params):
|
||||
self.update_ema(params, self.master_params, rate=rate)
|
||||
self.global_step += 1
|
||||
self.log_step()
|
||||
|
||||
if self.global_step % self.log_interval == 0:
|
||||
logger.dumpkvs()
|
||||
|
||||
if self.eval_data is not None and self.global_step % self.eval_interval == 0:
|
||||
if dist.get_rank() == 0:
|
||||
print('eval on validation set...')
|
||||
if self.train_type == 'LM_Diffusion':
|
||||
dev_dataloader = DataLoader(self.eval_data, batch_size=self.batch_size,
|
||||
drop_last=False,
|
||||
num_workers=20, collate_fn=Question_dataset.get_collate_fn())
|
||||
elif self.train_type == 'S2S_Diffusion':
|
||||
'''
|
||||
for s2s
|
||||
'''
|
||||
dev_dataloader = DataLoader(self.eval_data, batch_size=self.batch_size,
|
||||
drop_last=False,
|
||||
num_workers=20, collate_fn=S2S_dataset.get_collate_fn())
|
||||
else:
|
||||
raise NotImplementedError
|
||||
for step, batch in enumerate(dev_dataloader):
|
||||
self.forward_only(batch)
|
||||
if step > 10:
|
||||
break
|
||||
logger.dumpkvs()
|
||||
# save
|
||||
if self.global_step % self.save_interval == 0:
|
||||
self.save()
|
||||
|
||||
def save(self):
|
||||
|
||||
def save_checkpoint(rate, ema_params):
|
||||
model_to_save = get_model_obj(self.model)
|
||||
if not rate:
|
||||
model_state_dict = model_to_save.state_dict()
|
||||
else:
|
||||
model_state_dict = model_to_save.state_dict()
|
||||
for i, (name, _value) in enumerate(model_to_save.named_parameters()):
|
||||
assert name in model_state_dict
|
||||
model_state_dict[name] = ema_params[i]
|
||||
|
||||
opt_state_dict = self.optimizer.state_dict()
|
||||
sch_state_dict = self.scheduler.state_dict()
|
||||
offset = self.global_step
|
||||
state = CheckpointState(model_state_dict,
|
||||
opt_state_dict,
|
||||
sch_state_dict,
|
||||
offset,
|
||||
)
|
||||
if not rate:
|
||||
ckpt_path = os.path.join(self.checkpoint_path, 'model_checkpoint-' + str(offset))
|
||||
else:
|
||||
ckpt_path = os.path.join(self.checkpoint_path, 'ema_' + str(rate) + '_checkpoint-' + str(offset))
|
||||
|
||||
torch.save(state._asdict(), ckpt_path)
|
||||
logger.info('Saved checkpoint at %s', ckpt_path)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
save_checkpoint(0, None)
|
||||
for rate, params in zip(self.ema_rate, self.ema_params):
|
||||
save_checkpoint(rate, params)
|
||||
|
||||
|
||||
def forward_backward(self, batch):
|
||||
|
||||
if self.train_type == 'LM_Diffusion':
|
||||
t, weights = self.schedule_sampler.sample(batch[1].shape[0], self.device)
|
||||
inputs_text = {"query_ids": batch[1].long().to(self.device),
|
||||
"attention_mask_q": batch[2].long().to(self.device)}
|
||||
losses = self.diffusion.training_losses(self.model, inputs_text, t)
|
||||
elif self.train_type == 'S2S_Diffusion':
|
||||
|
||||
# for s2s
|
||||
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
|
||||
losses = self.diffusion.training_losses(self.model, batch, t)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if isinstance(self.schedule_sampler, LossAwareSampler):
|
||||
self.schedule_sampler.update_with_local_losses(
|
||||
t, losses["loss"].detach()
|
||||
)
|
||||
|
||||
loss = (losses["loss"] * weights).mean()
|
||||
log_loss_dict(
|
||||
self.diffusion, t, {k: v * weights for k, v in losses.items()}
|
||||
)
|
||||
|
||||
if self.use_fp16:
|
||||
loss_scale = 2 ** self.lg_loss_scale
|
||||
(loss * loss_scale).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
def forward_only(self, batch):
|
||||
with torch.no_grad():
|
||||
self.model.zero_grad()
|
||||
if self.train_type == 'LM_Diffusion':
|
||||
t, weights = self.schedule_sampler.sample(batch[1].shape[0], dist_util.dev())
|
||||
inputs_text = {"query_ids": batch[1].long().to(self.device),
|
||||
"attention_mask_q": batch[2].long().to(self.device)}
|
||||
|
||||
losses = self.diffusion.training_losses(self.model, inputs_text, t)
|
||||
elif self.train_type == 'S2S_Diffusion':
|
||||
|
||||
# for s2s
|
||||
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
|
||||
losses = self.diffusion.training_losses(self.model, batch, t)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
log_loss_dict(
|
||||
self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()}
|
||||
)
|
||||
|
||||
def _log_grad_norm(self):
|
||||
sqsum = 0.0
|
||||
for p in self.master_params:
|
||||
# print(p)
|
||||
sqsum += (p.grad ** 2).sum().item()
|
||||
logger.logkv_mean("grad_norm", np.sqrt(sqsum))
|
||||
|
||||
def log_step(self):
|
||||
logger.logkv("step", self.global_step)
|
||||
if self.use_fp16:
|
||||
logger.logkv("lg_loss_scale", self.lg_loss_scale)
|
||||
|
||||
def _anneal_lr(self):
|
||||
if not self.lr_anneal_steps:
|
||||
return
|
||||
frac_done = self.global_step / self.lr_anneal_steps
|
||||
lr = self.lr * (1 - frac_done)
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group["lr"] = lr
|
||||
|
||||
def grad_clip(self):
|
||||
# print('doing gradient clipping')
|
||||
max_grad_norm=self.gradient_clipping #3.0
|
||||
if hasattr(self.optimizer, "clip_grad_norm"):
|
||||
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
|
||||
self.optimizer.clip_grad_norm(max_grad_norm)
|
||||
# else:
|
||||
# assert False
|
||||
# elif hasattr(self.model, "clip_grad_norm_"):
|
||||
# # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
|
||||
# self.model.clip_grad_norm_(args.max_grad_norm)
|
||||
else:
|
||||
# Revert to normal clipping otherwise, handling Apex or full precision
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), #amp.master_params(self.opt) if self.use_apex else
|
||||
max_grad_norm,
|
||||
)
|
||||
|
||||
def _load_saved_state(self, saved_state: CheckpointState):
|
||||
self.global_step = saved_state.offset
|
||||
logger.info('Loading checkpoint @ step=%s', self.global_step)
|
||||
|
||||
logger.info('Loading saved model state ...')
|
||||
self.model.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection
|
||||
self.optimizer.load_state_dict(saved_state.optimizer_dict)
|
||||
self.scheduler.load_state_dict(saved_state.scheduler_dict)
|
||||
self.master_params = list(self.model.parameters())
|
||||
return self.global_step
|
||||
|
||||
|
||||
def update_ema(self, target_params, source_params, rate=0.99):
|
||||
"""
|
||||
Update target parameters to be closer to those of source parameters using
|
||||
an exponential moving average.
|
||||
|
||||
:param target_params: the target parameter sequence.
|
||||
:param source_params: the source parameter sequence.
|
||||
:param rate: the EMA rate (closer to 1 means slower).
|
||||
"""
|
||||
for targ, src in zip(target_params, source_params):
|
||||
# print("target_params:", targ.device)
|
||||
# print("source_params:", src.device)
|
||||
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
||||
|
||||
|
||||
|
||||
def log_loss_dict(diffusion, ts, losses):
|
||||
for key, values in losses.items():
|
||||
logger.logkv_mean(key, values.mean().item())
|
||||
# Log the quantiles (four quartiles, in particular).
|
||||
for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
|
||||
quartile = int(4 * sub_t / diffusion.num_timesteps)
|
||||
logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
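# Usage sketch (added for illustration, not part of the original commit): log_loss_dict
# only reads the number of timesteps from the diffusion object, so a stand-in works here.
def _example_log_loss_dict():
    from types import SimpleNamespace
    fake_diffusion = SimpleNamespace(num_timesteps=2000)
    ts = torch.tensor([10, 1500])
    losses = {"loss": torch.tensor([0.9, 0.2])}
    log_loss_dict(fake_diffusion, ts, losses)  # records "loss", "loss_q0" and "loss_q3"
    logger.dumpkvs()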
|
||||
|
||||
def get_model_obj(model: nn.Module):
|
||||
return model.module if hasattr(model, 'module') else model
|
||||
|
||||
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
|
||||
logger.info('Reading saved model from %s', model_file)
|
||||
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
|
||||
logger.info('model_state_dict keys %s', state_dict.keys())
|
||||
return CheckpointState(**state_dict)
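# Usage sketch (added for illustration, not part of the original commit): resuming a
# TrainLoop from a checkpoint written by save(); the path below is hypothetical.
def _example_resume_from_checkpoint(train_loop, ckpt_path="checkpoints/model_checkpoint-10000"):
    saved_state = load_states_from_checkpoint(ckpt_path)
    step = train_loop._load_saved_state(saved_state)  # restores model, optimizer, scheduler, offset
    logger.log("resumed training at global step", step)
    return step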
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,498 @@
|
|||
"""
|
||||
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
|
||||
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import os.path as osp
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
# import wandb
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
ERROR = 40
|
||||
|
||||
DISABLED = 50
|
||||
|
||||
|
||||
class KVWriter(object):
|
||||
def writekvs(self, kvs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SeqWriter(object):
|
||||
def writeseq(self, seq):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HumanOutputFormat(KVWriter, SeqWriter):
|
||||
def __init__(self, filename_or_file):
|
||||
if isinstance(filename_or_file, str):
|
||||
self.file = open(filename_or_file, "wt")
|
||||
self.own_file = True
|
||||
else:
|
||||
assert hasattr(filename_or_file, "read"), (
|
||||
"expected file or str, got %s" % filename_or_file
|
||||
)
|
||||
self.file = filename_or_file
|
||||
self.own_file = False
|
||||
|
||||
def writekvs(self, kvs):
|
||||
# Create strings for printing
|
||||
key2str = {}
|
||||
for (key, val) in sorted(kvs.items()):
|
||||
if hasattr(val, "__float__"):
|
||||
valstr = "%-8.3g" % val
|
||||
else:
|
||||
valstr = str(val)
|
||||
key2str[self._truncate(key)] = self._truncate(valstr)
|
||||
|
||||
# Find max widths
|
||||
if len(key2str) == 0:
|
||||
print("WARNING: tried to write empty key-value dict")
|
||||
return
|
||||
else:
|
||||
keywidth = max(map(len, key2str.keys()))
|
||||
valwidth = max(map(len, key2str.values()))
|
||||
|
||||
# Write out the data
|
||||
dashes = "-" * (keywidth + valwidth + 7)
|
||||
lines = [dashes]
|
||||
for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
|
||||
lines.append(
|
||||
"| %s%s | %s%s |"
|
||||
% (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
|
||||
)
|
||||
lines.append(dashes)
|
||||
self.file.write("\n".join(lines) + "\n")
|
||||
|
||||
# Flush the output to the file
|
||||
self.file.flush()
|
||||
|
||||
def _truncate(self, s):
|
||||
maxlen = 30
|
||||
return s[: maxlen - 3] + "..." if len(s) > maxlen else s
|
||||
|
||||
def writeseq(self, seq):
|
||||
seq = list(seq)
|
||||
for (i, elem) in enumerate(seq):
|
||||
self.file.write(elem)
|
||||
if i < len(seq) - 1: # add space unless this is the last one
|
||||
self.file.write(" ")
|
||||
self.file.write("\n")
|
||||
self.file.flush()
|
||||
|
||||
def close(self):
|
||||
if self.own_file:
|
||||
self.file.close()
|
||||
|
||||
|
||||
class JSONOutputFormat(KVWriter):
|
||||
def __init__(self, filename):
|
||||
self.file = open(filename, "wt")
|
||||
|
||||
def writekvs(self, kvs):
|
||||
for k, v in sorted(kvs.items()):
|
||||
if hasattr(v, "dtype"):
|
||||
kvs[k] = float(v)
|
||||
self.file.write(json.dumps(kvs) + "\n")
|
||||
self.file.flush()
|
||||
|
||||
def close(self):
|
||||
self.file.close()
|
||||
|
||||
|
||||
class CSVOutputFormat(KVWriter):
|
||||
def __init__(self, filename):
|
||||
self.file = open(filename, "w+t")
|
||||
self.keys = []
|
||||
self.sep = ","
|
||||
|
||||
def writekvs(self, kvs):
|
||||
# Add our current row to the history
|
||||
extra_keys = list(kvs.keys() - self.keys)
|
||||
extra_keys.sort()
|
||||
if extra_keys:
|
||||
self.keys.extend(extra_keys)
|
||||
self.file.seek(0)
|
||||
lines = self.file.readlines()
|
||||
self.file.seek(0)
|
||||
for (i, k) in enumerate(self.keys):
|
||||
if i > 0:
|
||||
self.file.write(",")
|
||||
self.file.write(k)
|
||||
self.file.write("\n")
|
||||
for line in lines[1:]:
|
||||
self.file.write(line[:-1])
|
||||
self.file.write(self.sep * len(extra_keys))
|
||||
self.file.write("\n")
|
||||
for (i, k) in enumerate(self.keys):
|
||||
if i > 0:
|
||||
self.file.write(",")
|
||||
v = kvs.get(k)
|
||||
if v is not None:
|
||||
self.file.write(str(v))
|
||||
self.file.write("\n")
|
||||
self.file.flush()
|
||||
|
||||
def close(self):
|
||||
self.file.close()
|
||||
|
||||
|
||||
class TensorBoardOutputFormat(KVWriter):
|
||||
"""
|
||||
Dumps key/value pairs into TensorBoard's numeric format.
|
||||
"""
|
||||
|
||||
def __init__(self, dir):
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
self.dir = dir
|
||||
self.step = 1
|
||||
prefix = "events"
|
||||
path = osp.join(osp.abspath(dir), prefix)
|
||||
import tensorflow as tf
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
self.tf = tf
|
||||
self.event_pb2 = event_pb2
|
||||
self.pywrap_tensorflow = pywrap_tensorflow
|
||||
self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
|
||||
|
||||
def writekvs(self, kvs):
|
||||
def summary_val(k, v):
|
||||
kwargs = {"tag": k, "simple_value": float(v)}
|
||||
return self.tf.Summary.Value(**kwargs)
|
||||
|
||||
summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
|
||||
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
|
||||
event.step = (
|
||||
self.step
|
||||
) # is there any reason why you'd want to specify the step?
|
||||
self.writer.WriteEvent(event)
|
||||
self.writer.Flush()
|
||||
self.step += 1
|
||||
|
||||
def close(self):
|
||||
if self.writer:
|
||||
self.writer.Close()
|
||||
self.writer = None
|
||||
|
||||
|
||||
def make_output_format(format, ev_dir, log_suffix=""):
|
||||
os.makedirs(ev_dir, exist_ok=True)
|
||||
if format == "stdout":
|
||||
return HumanOutputFormat(sys.stdout)
|
||||
elif format == "log":
|
||||
return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
|
||||
elif format == "json":
|
||||
return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
|
||||
elif format == "csv":
|
||||
return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
|
||||
elif format == "tensorboard":
|
||||
return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
|
||||
else:
|
||||
raise ValueError("Unknown format specified: %s" % (format,))
|
||||
|
||||
|
||||
# ================================================================
|
||||
# API
|
||||
# ================================================================
|
||||
|
||||
|
||||
def logkv(key, val):
|
||||
"""
|
||||
Log a value of some diagnostic
|
||||
Call this once for each diagnostic quantity, each iteration
|
||||
If called many times, last value will be used.
|
||||
"""
|
||||
get_current().logkv(key, val)
|
||||
|
||||
|
||||
def logkv_mean(key, val):
|
||||
"""
|
||||
The same as logkv(), but if called many times, values averaged.
|
||||
"""
|
||||
get_current().logkv_mean(key, val)
|
||||
|
||||
|
||||
def logkvs(d):
|
||||
"""
|
||||
Log a dictionary of key-value pairs
|
||||
"""
|
||||
for (k, v) in d.items():
|
||||
logkv(k, v)
|
||||
|
||||
|
||||
def dumpkvs():
|
||||
"""
|
||||
Write all of the diagnostics from the current iteration
|
||||
"""
|
||||
return get_current().dumpkvs()
|
||||
|
||||
|
||||
def getkvs():
|
||||
return get_current().name2val
|
||||
|
||||
|
||||
def log(*args, level=INFO):
|
||||
"""
|
||||
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
|
||||
"""
|
||||
get_current().log(*args, level=level)
|
||||
|
||||
|
||||
def debug(*args):
|
||||
log(*args, level=DEBUG)
|
||||
|
||||
|
||||
def info(*args):
|
||||
log(*args, level=INFO)
|
||||
|
||||
|
||||
def warn(*args):
|
||||
log(*args, level=WARN)
|
||||
|
||||
|
||||
def error(*args):
|
||||
log(*args, level=ERROR)
|
||||
|
||||
|
||||
def set_level(level):
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
get_current().set_level(level)
|
||||
|
||||
|
||||
def set_comm(comm):
|
||||
get_current().set_comm(comm)
|
||||
|
||||
|
||||
def get_dir():
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call start)
|
||||
"""
|
||||
return get_current().get_dir()
|
||||
|
||||
|
||||
record_tabular = logkv
|
||||
dump_tabular = dumpkvs
|
||||
|
||||
|
||||
@contextmanager
|
||||
def profile_kv(scopename):
|
||||
logkey = "wait_" + scopename
|
||||
tstart = time.time()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
get_current().name2val[logkey] += time.time() - tstart
|
||||
|
||||
|
||||
def profile(n):
|
||||
"""
|
||||
Usage:
|
||||
@profile("my_func")
|
||||
def my_func(): code
|
||||
"""
|
||||
|
||||
def decorator_with_name(func):
|
||||
def func_wrapper(*args, **kwargs):
|
||||
with profile_kv(n):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return func_wrapper
|
||||
|
||||
return decorator_with_name
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Backend
|
||||
# ================================================================
|
||||
|
||||
|
||||
def get_current():
|
||||
if Logger.CURRENT is None:
|
||||
_configure_default_logger()
|
||||
|
||||
return Logger.CURRENT
|
||||
|
||||
|
||||
class Logger(object):
|
||||
DEFAULT = None # A logger with no output files. (See right below class definition)
|
||||
# So that you can still log to the terminal without setting up any output files
|
||||
CURRENT = None # Current logger being used by the free functions above
|
||||
|
||||
def __init__(self, dir, output_formats, comm=None):
|
||||
self.name2val = defaultdict(float) # values this iteration
|
||||
self.name2cnt = defaultdict(int)
|
||||
self.level = INFO
|
||||
self.dir = dir
|
||||
self.output_formats = output_formats
|
||||
self.comm = comm
|
||||
|
||||
# Logging API, forwarded
|
||||
# ----------------------------------------
|
||||
def logkv(self, key, val):
|
||||
self.name2val[key] = val
|
||||
|
||||
def logkv_mean(self, key, val):
|
||||
oldval, cnt = self.name2val[key], self.name2cnt[key]
|
||||
self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
|
||||
self.name2cnt[key] = cnt + 1
|
||||
|
||||
def dumpkvs(self, prefix=None):
|
||||
if self.comm is None:
|
||||
d = self.name2val
|
||||
else:
|
||||
d = mpi_weighted_mean(
|
||||
self.comm,
|
||||
{
|
||||
name: (val, self.name2cnt.get(name, 1))
|
||||
for (name, val) in self.name2val.items()
|
||||
},
|
||||
)
|
||||
if self.comm.rank != 0:
|
||||
d["dummy"] = 1 # so we don't get a warning about empty dict
|
||||
# LISA
|
||||
# wandb.log({**d})
|
||||
out = d.copy() # Return the dict for unit testing purposes
|
||||
for fmt in self.output_formats:
|
||||
if isinstance(fmt, KVWriter):
|
||||
fmt.writekvs(d)
|
||||
self.name2val.clear()
|
||||
self.name2cnt.clear()
|
||||
return out
|
||||
|
||||
def log(self, *args, level=INFO):
|
||||
if self.level <= level:
|
||||
self._do_log(args)
|
||||
|
||||
# Configuration
|
||||
# ----------------------------------------
|
||||
def set_level(self, level):
|
||||
self.level = level
|
||||
|
||||
def set_comm(self, comm):
|
||||
self.comm = comm
|
||||
|
||||
def get_dir(self):
|
||||
return self.dir
|
||||
|
||||
def close(self):
|
||||
for fmt in self.output_formats:
|
||||
fmt.close()
|
||||
|
||||
# Misc
|
||||
# ----------------------------------------
|
||||
def _do_log(self, args):
|
||||
for fmt in self.output_formats:
|
||||
if isinstance(fmt, SeqWriter):
|
||||
fmt.writeseq(map(str, args))
|
||||
|
||||
|
||||
def get_rank_without_mpi_import():
|
||||
# check environment variables here instead of importing mpi4py
|
||||
# to avoid calling MPI_Init() when this module is imported
|
||||
for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
|
||||
if varname in os.environ:
|
||||
return int(os.environ[varname])
|
||||
return 0
|
||||
|
||||
|
||||
def mpi_weighted_mean(comm, local_name2valcount):
|
||||
"""
|
||||
Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
|
||||
Perform a weighted average over dicts that are each on a different node
|
||||
Input: local_name2valcount: dict mapping key -> (value, count)
|
||||
Returns: key -> mean
|
||||
"""
|
||||
all_name2valcount = comm.gather(local_name2valcount)
|
||||
if comm.rank == 0:
|
||||
name2sum = defaultdict(float)
|
||||
name2count = defaultdict(float)
|
||||
for n2vc in all_name2valcount:
|
||||
for (name, (val, count)) in n2vc.items():
|
||||
try:
|
||||
val = float(val)
|
||||
except ValueError:
|
||||
if comm.rank == 0:
|
||||
warnings.warn(
|
||||
"WARNING: tried to compute mean on non-float {}={}".format(
|
||||
name, val
|
||||
)
|
||||
)
|
||||
else:
|
||||
name2sum[name] += val * count
|
||||
name2count[name] += count
|
||||
return {name: name2sum[name] / name2count[name] for name in name2sum}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
|
||||
"""
|
||||
If comm is provided, average all numerical stats across that comm
|
||||
"""
|
||||
if dir is None:
|
||||
dir = os.getenv("OPENAI_LOGDIR")
|
||||
if dir is None:
|
||||
dir = osp.join(
|
||||
tempfile.gettempdir(),
|
||||
datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
|
||||
)
|
||||
assert isinstance(dir, str)
|
||||
dir = os.path.expanduser(dir)
|
||||
os.makedirs(os.path.expanduser(dir), exist_ok=True)
|
||||
|
||||
rank = get_rank_without_mpi_import()
|
||||
if rank > 0:
|
||||
log_suffix = log_suffix + "-rank%03i" % rank
|
||||
|
||||
if format_strs is None:
|
||||
if rank == 0:
|
||||
format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
|
||||
else:
|
||||
format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
|
||||
format_strs = filter(None, format_strs)
|
||||
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
|
||||
|
||||
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
|
||||
if output_formats:
|
||||
log("Logging to %s" % dir)
|
||||
|
||||
|
||||
def _configure_default_logger():
|
||||
configure()
|
||||
Logger.DEFAULT = Logger.CURRENT
|
||||
|
||||
|
||||
def reset():
|
||||
if Logger.CURRENT is not Logger.DEFAULT:
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = Logger.DEFAULT
|
||||
log("Reset logger")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def scoped_configure(dir=None, format_strs=None, comm=None):
|
||||
prevlogger = Logger.CURRENT
|
||||
configure(dir=dir, format_strs=format_strs, comm=comm)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = prevlogger
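# Usage sketch (added for illustration, not part of the original commit): a minimal
# logging loop with this module; the output directory and values are made up.
def _example_logging_run():
    configure(dir="logs/example_run", format_strs=["stdout", "csv"])
    for step in range(3):
        logkv("step", step)
        logkv_mean("loss", 0.5 / (step + 1))
        dumpkvs()  # writes one row to stdout and progress.csv, then clears the buffer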
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
"""
|
||||
Helpers for various likelihood-based losses. These are ported from the original
|
||||
Ho et al. diffusion models codebase:
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch as th
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
Compute the KL divergence between two gaussians.
|
||||
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, th.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for th.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
# print(logvar2.shape)
|
||||
# temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2))
|
||||
# print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, mse={((mean1 - mean2) ** 2).mean().item()}')
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ th.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
||||
)
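# Sanity check (added for illustration): the KL of a Gaussian with itself is 0, and
# KL(N(0, 1) || N(1, 1)) reduces to 0.5 * (difference of means)^2 = 0.5.
def _example_normal_kl():
    zero = th.tensor(0.0)
    assert th.allclose(normal_kl(zero, zero, zero, zero), th.tensor(0.0))
    assert th.allclose(normal_kl(zero, zero, th.tensor(1.0), zero), th.tensor(0.5))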
|
||||
|
||||
|
||||
def approx_standard_normal_cdf(x):
|
||||
"""
|
||||
A fast approximation of the cumulative distribution function of the
|
||||
standard normal.
|
||||
"""
|
||||
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
||||
|
||||
|
||||
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
||||
"""
|
||||
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
||||
given image.
|
||||
|
||||
:param x: the target images. It is assumed that this was uint8 values,
|
||||
rescaled to the range [-1, 1].
|
||||
:param means: the Gaussian mean Tensor.
|
||||
:param log_scales: the Gaussian log stddev Tensor.
|
||||
:return: a tensor like x of log probabilities (in nats).
|
||||
"""
|
||||
assert x.shape == means.shape == log_scales.shape
|
||||
centered_x = x - means
|
||||
inv_stdv = th.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
||||
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
log_probs = th.where(
|
||||
x < -0.999,
|
||||
log_cdf_plus,
|
||||
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
||||
)
|
||||
assert log_probs.shape == x.shape
|
||||
return log_probs
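# Usage sketch (added for illustration): scoring zeros under a standard Gaussian; each
# returned value is the log-probability of the 2/255-wide bin around the target.
def _example_discretized_gaussian_log_likelihood():
    x = th.zeros(2, 4)
    log_probs = discretized_gaussian_log_likelihood(x, means=th.zeros(2, 4), log_scales=th.zeros(2, 4))
    return log_probs.exp()  # roughly 2/255 * pdf(0) ~= 0.0031 per element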
|
||||
|
||||
def gaussian_density(x, *, means, log_scales):
|
||||
from torch.distributions import Normal
|
||||
normal_dist = Normal(means, log_scales.exp())
|
||||
logp = normal_dist.log_prob(x)
|
||||
return logp
|
||||
|
||||
|
||||
def discretized_text_log_likelihood(x, *, means, log_scales):
|
||||
"""
|
||||
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
||||
given image.
|
||||
|
||||
:param x: the target images. It is assumed that this was uint8 values,
|
||||
rescaled to the range [-1, 1].
|
||||
:param means: the Gaussian mean Tensor.
|
||||
:param log_scales: the Gaussian log stddev Tensor.
|
||||
:return: a tensor like x of log probabilities (in nats).
|
||||
"""
|
||||
print(x.shape, means.shape)
|
||||
# assert x.shape == means.shape == log_scales.shape
|
||||
print(x, means)
|
||||
centered_x = x - means
|
||||
inv_stdv = th.exp(-log_scales)
|
||||
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
||||
cdf_plus = approx_standard_normal_cdf(plus_in)
|
||||
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
||||
cdf_min = approx_standard_normal_cdf(min_in)
|
||||
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
||||
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
||||
cdf_delta = cdf_plus - cdf_min
|
||||
log_probs = th.where(
|
||||
x < -0.999,
|
||||
log_cdf_plus,
|
||||
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
||||
)
|
||||
assert log_probs.shape == x.shape
|
||||
return log_probs
|
|
@ -0,0 +1,170 @@
|
|||
"""
|
||||
Various utilities for neural networks.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * th.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def update_ema(target_params, source_params, rate=0.99):
|
||||
"""
|
||||
Update target parameters to be closer to those of source parameters using
|
||||
an exponential moving average.
|
||||
|
||||
:param target_params: the target parameter sequence.
|
||||
:param source_params: the source parameter sequence.
|
||||
:param rate: the EMA rate (closer to 1 means slower).
|
||||
"""
|
||||
for targ, src in zip(target_params, source_params):
|
||||
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
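# Example (added for illustration): after one EMA update with rate 0.9, a target
# parameter at 0 tracking a source at 1 moves to 0.1.
def _example_update_ema():
    targ = [th.zeros(3)]
    src = [th.ones(3)]
    update_ema(targ, src, rate=0.9)
    return targ[0]  # tensor([0.1000, 0.1000, 0.1000])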
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
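# Example (added for illustration): mean_flat averages over everything except the batch axis.
def _example_mean_flat():
    t = th.arange(12, dtype=th.float32).reshape(2, 2, 3)
    return mean_flat(t)  # tensor([2.5000, 8.5000]), one mean per batch element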
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = th.exp(
|
||||
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
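# Usage sketch (added for illustration): sinusoidal embeddings for a batch of timesteps.
def _example_timestep_embedding():
    t = th.tensor([0, 10, 100])
    emb = timestep_embedding(t, dim=128)
    return emb.shape  # torch.Size([3, 128]); cosine terms in the first half, sine in the second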
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(th.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
with th.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with th.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = th.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
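# Usage sketch (added for illustration): gradient checkpointing a small projection.
# With flag=True the forward pass runs without caching activations and recomputes them
# in backward; the resulting gradients match the uncheckpointed computation.
def _example_checkpoint():
    x = th.randn(4, 8, requires_grad=True)
    w = th.randn(8, 8, requires_grad=True)
    out = checkpoint(lambda inp: th.tanh(inp @ w), (x,), (w,), True)
    out.sum().backward()
    return x.grad.shape, w.grad.shape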
|
|
@ -0,0 +1,158 @@
|
|||
from model.Diffusion_LM import Diffusion_LM, CrossAttention_Diffusion_LM
|
||||
from diffusion_util import gaussian_diffusion as gd
|
||||
from diffusion_util.respace import SpacedDiffusion, space_timesteps
|
||||
from diffusion_util.gaussian_diffusion import GaussianDiffusion
|
||||
|
||||
|
||||
|
||||
def create_model_and_diffusion(
|
||||
args
|
||||
):
|
||||
model = create_model(
|
||||
model_channels=args.model_channels,
|
||||
learn_sigma=args.learn_sigma,
|
||||
dropout=args.dropout,
|
||||
model_arch=args.model_arch,
|
||||
in_channel=args.in_channel,
|
||||
out_channel=args.out_channel,
|
||||
vocab_size=args.vocab_size,
|
||||
config_name=args.config_name,
|
||||
logits_mode=args.logits_mode,
|
||||
init_pretrained=args.init_pretrained,
|
||||
token_emb_type=args.token_emb_type,
|
||||
)
|
||||
diffusion = create_gaussian_diffusion(
|
||||
steps=args.diffusion_steps,
|
||||
learn_sigma=args.learn_sigma,
|
||||
sigma_small=args.sigma_small,
|
||||
noise_schedule=args.noise_schedule,
|
||||
use_kl=args.use_kl,
|
||||
predict_xstart=args.predict_xstart,
|
||||
rescale_timesteps=args.rescale_timesteps,
|
||||
rescale_learned_sigmas=args.rescale_learned_sigmas,
|
||||
model_arch=args.model_arch,
|
||||
training_mode=args.training_mode,
|
||||
)
|
||||
return model, diffusion
|
||||
|
||||
'''
|
||||
create diffusion model
|
||||
'''
|
||||
def create_model(
|
||||
model_channels,
|
||||
learn_sigma,
|
||||
dropout,
|
||||
model_arch,
|
||||
in_channel=8,
|
||||
out_channel=8,
|
||||
vocab_size=None,
|
||||
config_name='',
|
||||
logits_mode=1,
|
||||
init_pretrained=True,
|
||||
token_emb_type='pretrain',
|
||||
):
|
||||
print(f'creating model, based on {model_arch}')
|
||||
|
||||
if model_arch == 'transformer':
|
||||
|
||||
return Diffusion_LM(
|
||||
in_channels=in_channel,
|
||||
model_channels=model_channels,
|
||||
out_channels=(out_channel if not learn_sigma else out_channel*2),
|
||||
dropout=dropout,
|
||||
config_name=config_name,
|
||||
vocab_size=vocab_size,
|
||||
logits_mode=logits_mode,
|
||||
init_pretrained=init_pretrained,
|
||||
token_emb_type=token_emb_type,
|
||||
)
|
||||
|
||||
elif model_arch == 's2s_CAT':
|
||||
|
||||
return CrossAttention_Diffusion_LM(
|
||||
in_channels=in_channel,
|
||||
model_channels=model_channels,
|
||||
out_channels=(out_channel if not learn_sigma else out_channel * 2),
|
||||
dropout=dropout,
|
||||
config_name=config_name,
|
||||
vocab_size=vocab_size,
|
||||
logits_mode=logits_mode,
|
||||
init_pretrained=init_pretrained,
|
||||
token_emb_type=token_emb_type,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
'''
|
||||
create diffusion process
|
||||
'''
|
||||
def create_gaussian_diffusion(
|
||||
steps=1000,
|
||||
learn_sigma=False,
|
||||
sigma_small=False,
|
||||
noise_schedule="cosine",
|
||||
use_kl=False,
|
||||
predict_xstart=False,
|
||||
rescale_timesteps=False,
|
||||
rescale_learned_sigmas=False,
|
||||
timestep_respacing="",
|
||||
model_arch='transformer',
|
||||
training_mode='e2e',
|
||||
):
|
||||
|
||||
# betas are determined by the maximum number of diffusion steps T and the noise (variance) schedule
|
||||
print("noise_schedule: ", noise_schedule)
|
||||
print("Diffusion Steps: ", steps)
|
||||
betas = gd.get_named_beta_schedule(noise_schedule, steps)
|
||||
print("betas: ", betas)
|
||||
|
||||
# determine the loss function used in training
|
||||
if training_mode == 'e2e' or training_mode == 's2s':
|
||||
# end to end training
|
||||
if use_kl:
|
||||
loss_type = gd.LossType.E2E_KL
|
||||
else:
|
||||
loss_type = gd.LossType.E2E_MSE
|
||||
elif training_mode == 'e2e-simple':
|
||||
if use_kl:
|
||||
loss_type = gd.LossType.E2E_Simple_KL
|
||||
else:
|
||||
loss_type = gd.LossType.E2E_Simple_MSE
|
||||
|
||||
else:
|
||||
if use_kl:
|
||||
loss_type = gd.LossType.RESCALED_KL
|
||||
elif rescale_learned_sigmas:
|
||||
loss_type = gd.LossType.RESCALED_MSE
|
||||
else:
|
||||
loss_type = gd.LossType.MSE
|
||||
|
||||
|
||||
if not timestep_respacing:
|
||||
timestep_respacing = [steps]
|
||||
print("Diffusion Loss Type: ", loss_type, " , Whether to learn sigma: ", learn_sigma)
|
||||
print("Diffusion predict xstart: ", predict_xstart)
|
||||
return GaussianDiffusion(
|
||||
betas=betas,
|
||||
model_mean_type=(
|
||||
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
|
||||
),
|
||||
model_var_type=(
|
||||
(
|
||||
gd.ModelVarType.FIXED_LARGE
|
||||
if not sigma_small
|
||||
else gd.ModelVarType.FIXED_SMALL
|
||||
)
|
||||
if not learn_sigma
|
||||
else gd.ModelVarType.LEARNED_RANGE
|
||||
),
|
||||
loss_type=loss_type,
|
||||
rescale_timesteps=rescale_timesteps,
|
||||
model_arch=model_arch,
|
||||
training_mode=training_mode,
|
||||
)
|
||||
|
||||
|
||||
def args_to_dict(args, keys):
|
||||
return {k: getattr(args, k) for k in keys}
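# Usage sketch (added for illustration, not part of the original commit): building only
# the diffusion process with explicit keyword arguments instead of the argparse
# namespace; the values below are illustrative.
def _example_build_diffusion():
    return create_gaussian_diffusion(
        steps=2000,
        noise_schedule="sqrt",
        predict_xstart=True,
        training_mode="s2s",
    )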
|