This commit is contained in:
weizhen 2023-02-16 09:28:46 +08:00
Parent 0a1b59cb95
Commit b9da891396
117 changed files: 7699 additions and 155 deletions

230
GENIE/Genie_Finetune.py Normal file
View File

@@ -0,0 +1,230 @@
import argparse
import os
from transformers import set_seed
from diffusion_util.resample import create_named_schedule_sampler
from transformers import AutoTokenizer
import json
from util import logger
from train_util import dist_util
import torch
import torch.distributed as dist
from util.util import (
create_model_and_diffusion,
args_to_dict,
)
import collections
from data_util.s2s_data_util import load_s2s_data
from train_util.train_util import TrainLoop
from torch.serialization import default_restore_location
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
CheckpointState = collections.namedtuple("CheckpointState",
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
def get_arguments():
parser = argparse.ArgumentParser()
# out path
parser.add_argument('--checkpoint_path', type=str, default='', help='output path')
# load pretrain
parser.add_argument('--pretrain_model_path', type=str, default=None, help='using pretraining diffusion')
# load model
parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
    parser.add_argument('--model_channels', type=int, default=768, help='Try to set it to the same size as the model hidden size')
    parser.add_argument('--in_channel', type=int, default=768, help='The input channel size must be the same as the word embedding size')
    parser.add_argument('--out_channel', type=int, default=768, help='The output dimension is recommended to match the word embedding size for easy decoding')
parser.add_argument('--dropout', type=float, default=0.1, help='')
parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
parser.add_argument('--logits_mode', type=int, default=1, help='final logits mode of Diffusion model')
parser.add_argument('--vocab_size', type=int, default=30522, help='vocab size')
parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='')
parser.add_argument('--token_emb_type', type=str, default='random', help='token embedding type')
parser.add_argument("--init_pretrained", default=False, action="store_true", help="Whether to using pretrain BERT encoder")
parser.add_argument("--fix_encoder", default=False, action="store_true",
help="Whether to training encoder")
# load diffusion
parser.add_argument('--diffusion_steps', type=int, default=2000, help='Diffusion model maximum T')
    parser.add_argument('--use_kl', default=False, action="store_true", help="Whether to use the KL loss in the diffusion loss")
    parser.add_argument('--training_mode', type=str, default='e2e', help='e2e simple loss, e2e loss, or s2s loss')
    parser.add_argument('--noise_schedule', type=str, default='sqrt', help='Noise schedule of the Gaussian diffusion process')
    parser.add_argument('--predict_xstart', default=False, action="store_true", help="Model prediction target: if set, predict x_start; otherwise predict epsilon")
parser.add_argument("--sigma_small", default=False, action="store_true", help="about learning variance")
parser.add_argument("--rescale_learned_sigmas", default=True, action="store_false", help="about learning variance")
parser.add_argument("--rescale_timesteps", default=True, action="store_false", help="about time rescale")
# sample t
    parser.add_argument('--schedule_sampler', type=str, default='uniform', help='how to sample t per batch: "uniform" samples timesteps uniformly, "loss-second-moment" samples according to the loss')
# data args
parser.add_argument('--data_path', type=str, default='',help='data path')
parser.add_argument('--data_name', type=str, default='', help='data name')
# for seq2seq
parser.add_argument('--src_max_len', type=int, default=144, help='src max len')
parser.add_argument('--tgt_max_len', type=int, default=32, help='tgt max len')
    parser.add_argument('--answer_max_len', type=int, default=10, help='answer max len')
# for doc2query
parser.add_argument('--text_max_len', type=int, default=None, help='text max len')
parser.add_argument('--pas_max_len', type=int, default=None, help='pas max len')
# training args
parser.add_argument('--train_type', type=str, default='LM_Diffusion', help='LM_Diffusion or S2S_Diffusion')
parser.add_argument('--lr_anneal_steps', type=int, default=200000, help='total step')
parser.add_argument('--batch_size', type=int, default=64, help='')
parser.add_argument('--lr', type=float, default=1e-04, help='')
parser.add_argument('--warmup_steps', type=int, default=20000, help='')
    parser.add_argument('--ema_rate', type=str, default='0.9999', help='EMA rate used to stabilize training')
parser.add_argument('--resume_checkpoint', type=str, default=None, help='')
parser.add_argument('--eval_interval', type=int, default=2000, help='')
parser.add_argument('--log_interval', type=int, default=100, help='')
parser.add_argument('--save_interval', type=int, default=50000, help='')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='')
parser.add_argument('--gradient_clipping', type=float, default=-1., help='')
parser.add_argument("--use_fp16", default=False, action="store_true", help="about learning variance")
parser.add_argument('--fp16_scale_growth', type=float, default=1e-3, help='')
# seed
parser.add_argument('--seed', type=int, default=101, help='')
    # multi-gpu
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For remote debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For remote debugging.")
args = parser.parse_args()
return args
def setup_env(args):
if args.local_rank == -1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
    else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1
args.device = device
# store args
if args.local_rank != -1:
args.world_size = torch.distributed.get_world_size()
args.rank = dist.get_rank()
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
logger.info('Reading saved model from %s', model_file)
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
logger.info('model_state_dict keys %s', state_dict.keys())
return CheckpointState(**state_dict)
def main():
# args setting
args = get_arguments()
# out dir set
if dist.get_rank() == 0:
if not os.path.exists(args.checkpoint_path):
os.makedirs(args.checkpoint_path)
# dist.barrier()
logger.log(f'saving the hyperparameters to {args.checkpoint_path}/training_args.json')
with open(f'{args.checkpoint_path}/training_args.json', 'w') as f:
json.dump(args.__dict__, f, indent=2)
# seed setting
set_seed(args.seed)
# dpp setting
setup_env(args)
# dist_util.setup_dist()
# logger setting
log_path = os.path.join(args.checkpoint_path, 'log.txt')
logger.configure(dir=log_path)
model, diffusion = create_model_and_diffusion(
args
)
if args.pretrain_model_path is not None:
print("load model ckpt at :", args.pretrain_model_path)
saved_state = load_states_from_checkpoint(args.pretrain_model_path)
model.load_state_dict(saved_state.model_dict, strict=False)
model.to(args.device)
pytorch_total_params = sum(p.numel() for p in model.parameters())
logger.log(f'the parameter count is {pytorch_total_params}')
'''
time step schedule sampler
'''
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
'''
tokenize
'''
logger.log("loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
'''
for s2s
'''
# load data (train)
train_data = load_s2s_data(
args,
split='train',
padding_mode='max_len',
tokenizer=tokenizer,
)
# load data (dev)
dev_data = load_s2s_data(
args,
split='dev',
padding_mode='max_len',
tokenizer=tokenizer,
)
'''
training
'''
logger.log("training Diffusion LM model...")
TrainLoop(
# training type
train_type=args.train_type,
# Training Core
model=model,
diffusion=diffusion,
data=train_data,
eval_data=dev_data,
schedule_sampler=schedule_sampler,
checkpoint_path=args.checkpoint_path,
# Training Parameters
batch_size=args.batch_size,
lr=args.lr,
ema_rate=args.ema_rate,
weight_decay=args.weight_decay,
lr_anneal_steps=args.lr_anneal_steps,
gradient_clipping=args.gradient_clipping,
# fp16
use_fp16=args.use_fp16,
fp16_scale_growth=args.fp16_scale_growth,
# Training Log
resume_checkpoint=args.resume_checkpoint,
eval_interval=args.eval_interval,
log_interval=args.log_interval,
save_interval=args.save_interval,
# device
device=args.device,
# finetune data name
data_name=args.data_name
).run_loop()
if __name__ == "__main__":
main()

397
GENIE/Genie_Generate.py Normal file
View File

@@ -0,0 +1,397 @@
import os
from util import logger
from train_util import dist_util
from util.util import (
create_model_and_diffusion,
args_to_dict,
)
# from transformers import set_seed
import torch
import collections
import argparse
from transformers import AutoTokenizer
import numpy as np
from functools import partial
from data_util.s2s_data_util import load_s2s_data
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from data_util.s2s_data_util import S2S_dataset, QG_dataset_Diff
from torch.serialization import default_restore_location
from transformers import (
BertModel,
BertConfig,
AutoTokenizer,
)
from data_util.text_data_util import load_data_text
from tqdm import tqdm
import random
def get_arguments():
parser = argparse.ArgumentParser()
# out path
parser.add_argument('--generate_path', type=str, default='', help='output path')
parser.add_argument('--eval_model_path', type=str, default='', help='model path')
parser.add_argument('--num_samples', type=int, default=50, help='sample query')
parser.add_argument('--interval_step', type=int, default=1, help='inference t interval step')
# load model
parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
    parser.add_argument('--model_channels', type=int, default=768,
                        help='Try to set it to the same size as the model hidden size')
    parser.add_argument('--in_channel', type=int, default=768,
                        help='The input channel size must be the same as the word embedding size')
    parser.add_argument('--out_channel', type=int, default=768,
                        help='The output dimension is recommended to match the word embedding size for easy decoding')
parser.add_argument('--dropout', type=float, default=0.1, help='')
parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
parser.add_argument('--logits_mode', type=int, default=1, help='final logits mode of Diffusion model')
parser.add_argument('--vocab_size', type=int, default=30522, help='vocab size')
parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='')
parser.add_argument('--token_emb_type', type=str, default='random', help='token embedding type')
parser.add_argument("--init_pretrained", default=False, action="store_true",
help="Whether to using pretrain BERT encoder")
# load diffusion
# parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
parser.add_argument('--diffusion_steps', type=int, default=2000, help='Diffusion model maximum T')
# parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
    parser.add_argument('--use_kl', default=False, action="store_true",
                        help="Whether to use the KL loss in the diffusion loss")
    parser.add_argument('--training_mode', type=str, default='e2e', help='e2e simple loss or e2e loss')
    parser.add_argument('--noise_schedule', type=str, default='sqrt',
                        help='Noise schedule of the Gaussian diffusion process')
    parser.add_argument('--predict_xstart', default=False, action="store_true",
                        help="Model prediction target: if set, predict x_start; otherwise predict epsilon")
parser.add_argument("--sigma_small", default=False, action="store_true", help="about learning variance")
parser.add_argument("--rescale_learned_sigmas", default=True, action="store_false", help="about learning variance")
parser.add_argument("--rescale_timesteps", default=True, action="store_false", help="about time rescale")
# data args
parser.add_argument('--data_path', type=str, default='', help='data path')
parser.add_argument('--data_name', type=str, default='', help='data name')
# for seq2seq
parser.add_argument('--src_max_len', type=int, default=144, help='src max len')
parser.add_argument('--tgt_max_len', type=int, default=32, help='tgt max len')
    parser.add_argument('--answer_max_len', type=int, default=10, help='answer max len')
# gen args
parser.add_argument('--batch_size', type=int, default=64, help='')
# seed
parser.add_argument('--seed', type=int, default=101, help='')
#
    # multi-gpu
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For remote debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For remote debugging.")
args = parser.parse_args()
return args
CheckpointState = collections.namedtuple("CheckpointState",
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
logger.info('Reading saved model from %s', model_file)
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
logger.info('model_state_dict keys %s', state_dict.keys())
return CheckpointState(**state_dict)
'''
rounding
'''
def denoised_fn_round(args, model, text_emb, t):
# thresh_t = 50
# # print(thresh_t)
# if thresh_t is not None and t[0] > thresh_t:
# return text_emb
if args.model_arch == '1d-unet':
text_emb = text_emb.permute(0, 2, 1)
# return text_emb
# print(t.float().mean(), t[0])
# assert t.float().mean() == t[0].float()
# print(text_emb.shape) # bsz, seqlen, dim
down_proj_emb = model.weight # input_embs
# print(t)
old_shape = text_emb.shape
old_device = text_emb.device
def get_efficient_knn(down_proj_emb, text_emb, dist='l2'):
if dist == 'l2':
emb_norm = (down_proj_emb ** 2).sum(-1).view(-1, 1) # vocab
text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
# print(emb_norm.shape, arr_norm.shape)
dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(down_proj_emb,
text_emb_t) # (vocab, d) x (d, bsz*seqlen)
dist = torch.clamp(dist, 0.0, np.inf)
# print(dist.shape)
topk_out = torch.topk(-dist, k=1, dim=0)
# adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
# down_proj_emb.size(0), -1, -1)
# adjacency = -th.norm(adjacency, dim=-1)
# topk_out = th.topk(adjacency, k=1, dim=0)
# print(topk_out1.indices == topk_out.indices)
# assert th.all(topk_out1.indices == topk_out.indices)
return topk_out.values, topk_out.indices
def get_knn(down_proj_emb, text_emb, dist='l2'):
if dist == 'l2':
adjacency = down_proj_emb.unsqueeze(1).expand(-1, text_emb.size(0), -1) - text_emb.unsqueeze(0).expand(
down_proj_emb.size(0), -1, -1)
adjacency = -torch.norm(adjacency, dim=-1)
topk_out = torch.topk(adjacency, k=1, dim=0)
return topk_out.values, topk_out.indices
dist = 'l2'
if len(text_emb.shape) > 2:
text_emb = text_emb.reshape(-1, text_emb.size(-1))
else:
text_emb = text_emb
# val, indices = get_knn(down_proj_emb,
# text_emb.to(down_proj_emb.device), dist=dist)
val, indices = get_efficient_knn(down_proj_emb,
text_emb.to(down_proj_emb.device), dist=dist)
rounded_tokens = indices[0]
# print(rounded_tokens.shape)
new_embeds = model(rounded_tokens).view(old_shape).to(old_device)
if args.model_arch == '1d-unet':
new_embeds = new_embeds.permute(0, 2, 1)
return new_embeds
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
def setup_env(args):
if args.local_rank == -1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
    else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1
args.device = device
# store args
if args.local_rank != -1:
args.world_size = torch.distributed.get_world_size()
args.rank = dist.get_rank()
def main():
# env setting
args = get_arguments()
# setup_seed(args.seed)
setup_env(args)
if dist.get_rank() == 0:
if not os.path.exists(args.generate_path):
os.makedirs(args.generate_path)
log_path = os.path.join(args.generate_path, 'log')
logger.configure(dir=log_path)
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# define model and diffusion
model, diffusion = create_model_and_diffusion(
args
)
model.to(args.device)
model.eval()
# load trained model
model_saved_state = load_states_from_checkpoint(args.eval_model_path)
model.load_state_dict(model_saved_state.model_dict)
pytorch_total_params = sum(p.numel() for p in model.parameters())
logger.log(f'the parameter count is {pytorch_total_params}')
if dist.get_world_size() > 1:
model = DDP(
model, device_ids=[dist.get_rank()], output_device=dist.get_rank(), find_unused_parameters=False,
)
logger.log("sampling text from random noise...")
print("sample num is :", args.num_samples)
print("sample interval step is :", args.interval_step)
print("total inverse diffusion step is :", 2000 // args.interval_step)
sample_fn = (
diffusion.p_sample_loop
)
if dist.get_world_size() > 1:
emb_model = model.module.word_embedding
else:
emb_model = model.word_embedding
if args.model_arch == 'transformer':
sample_shape = (args.num_samples, args.text_max_len, args.in_channel)
sample = sample_fn(
model,
sample_shape,
clip_denoised=False,
denoised_fn=partial(denoised_fn_round, args, emb_model.cuda()),
model_kwargs=None,
top_p=-1.0,
)
print("sample result shape: ", sample.shape)
print('decoding for e2e... ')
logits = model.get_logits(sample)
cands = torch.topk(logits, k=1, dim=-1)
sample_id_list = cands.indices
print("decode id list example :", type(sample_id_list[0]), " ", sample_id_list[0])
logger.log("creating tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
for sample_id in sample_id_list:
sentence = tokenizer.decode(sample_id.squeeze())
print(sentence)
elif args.model_arch == 's2s_CAT':
# bert tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
print("-------------------------------------------------------------")
print("start generate query from dev dataset, for every passage, we generate ", args.num_samples, " querys...")
print("-------------------------------------------------------------")
print("***** load " + args.data_name + " test src dataset*****")
src = []
test_src_path = os.path.join(args.data_path, args.data_name + "/org_data/test.src")
with open(test_src_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
src.append(text)
print("***** load " + args.data_name + " dev tgt dataset*****")
tgt = []
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
tgt.append(text)
shard_size = len(src) // args.world_size
start_idx = args.local_rank * shard_size
end_idx = start_idx + shard_size
if args.local_rank == args.world_size - 1:
end_idx = len(src)
scr_data_piece = src[start_idx:end_idx]
tgt_data_piece = tgt[start_idx:end_idx]
print('generation for ', len(scr_data_piece), " src text from idx ", start_idx, " to ", end_idx)
if args.data_name == "squadqg_data":
test_dataset = QG_dataset_Diff(scr_data_piece, tgt_data_piece, tokenizer, src_maxlength=args.src_max_len,
answer_maxlength=args.answer_max_len, tgt_maxlength=args.tgt_max_len)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, drop_last=False,
num_workers=20, collate_fn=QG_dataset_Diff.get_collate_fn())
else:
test_dataset = S2S_dataset(scr_data_piece, tgt_data_piece, tokenizer, src_maxlength=args.src_max_len,
tgt_maxlength=args.tgt_max_len)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, drop_last=False,
num_workers=20, collate_fn=S2S_dataset.get_collate_fn())
if args.generate_path is not None:
model_gen_files = []
if os.path.exists(args.generate_path):
for item in os.scandir(args.generate_path):
if item.is_file():
if "gen_seed" in item.path:
model_gen_files.append(item.path)
if len(model_gen_files) != 0 :
model_gen_files.sort(key=lambda f: int((f.split('_epoch')[-1]).split('.txt')[0]), reverse=True)
epoch_num = int((model_gen_files[0].split('_epoch')[-1]).split('.txt')[0])
logger.info("***** load " + model_gen_files[0] + " *****")
else:
epoch_num = 0
else:
logger.info("generate_path is None")
exit(0)
for epoch in range(args.num_samples - epoch_num):
each_sample_list = []
print("-------------------------------------------------------------")
print("start sample ", epoch+1+epoch_num, " epoch...")
print("-------------------------------------------------------------")
for index, batch in enumerate(tqdm(test_dataloader)):
'''
for s2s
'''
input_shape = (batch['src_input_ids'].shape[0], args.tgt_max_len, args.in_channel)
src_input_ids = batch['src_input_ids']
tgt_input_ids = batch['tgt_input_ids']
# print(p_input_ids.shape)
src_attention_mask = batch['src_attention_mask']
model_kwargs = {'src_input_ids' : src_input_ids, 'src_attention_mask': src_attention_mask}
sample = sample_fn(
model,
input_shape,
clip_denoised=False,
denoised_fn=partial(denoised_fn_round, args, emb_model.cuda()),
model_kwargs=model_kwargs,
top_p=-1.0,
interval_step=args.interval_step,
)
print("sample result shape: ", sample.shape)
print('decoding for e2e... ')
logits = model.module.get_logits(sample)
cands = torch.topk(logits, k=1, dim=-1)
sample_id_list = cands.indices
#print("decode id list example :", type(sample_id_list[0]), " ", sample_id_list[0])
'''
for s2s
'''
# print("src text: ", tokenizer.decode(src_input_ids.squeeze()))
# print("tgt text: ", tokenizer.decode(tgt_input_ids.squeeze()))
print("sample control generate query: ")
for sample_id in sample_id_list:
sentence = tokenizer.decode(sample_id.squeeze())
each_sample_list.append(clean(sentence))
# print(sentence)
# total_sample_list.append(each_sample_list)
out_path = os.path.join(args.generate_path, "rank" + str(dist.get_rank()) + "_gen_seed_101" +
"_num" + str(args.num_samples) + "_epoch" + str(epoch + 1 + epoch_num) + ".txt")
with open(out_path, 'w') as f:
for sentence in each_sample_list:
f.write(sentence + '\n')
else:
        raise NotImplementedError
def clean(sentence):
sentence = sentence.replace('[CLS]', '')
sentence = sentence.replace('[SEP]', '')
sentence = sentence.replace('[PAD]', '')
sentence = sentence.replace('[UNK]', 'unk')
return sentence.strip()
if __name__ == "__main__":
main()

220
GENIE/Genie_Pretrain.py Normal file
View File

@@ -0,0 +1,220 @@
import argparse
import os
from transformers import set_seed
from diffusion_util.resample import create_named_schedule_sampler
from transformers import AutoTokenizer
import json
from util import logger
from train_util import dist_util
import torch
import torch.distributed as dist
from util.util import (
create_model_and_diffusion,
args_to_dict,
)
from data_util.pretrain_data_util import load_pretrain_data
from torch.serialization import default_restore_location
from train_util.pretrain_util import PretrainLoop
import collections
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
CheckpointState = collections.namedtuple("CheckpointState",
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
def get_arguments():
parser = argparse.ArgumentParser()
# out path
parser.add_argument('--checkpoint_path', type=str, default='', help='output path')
parser.add_argument('--pretrain_model_path', type=str, default=None, help='continue train')
# load model
parser.add_argument('--model_arch', type=str, default='transformer', help='Core architecture of diffusion model')
    parser.add_argument('--model_channels', type=int, default=768, help='Try to set it to the same size as the model hidden size')
    parser.add_argument('--in_channel', type=int, default=768, help='The input channel size must be the same as the word embedding size')
    parser.add_argument('--out_channel', type=int, default=768, help='The output dimension is recommended to match the word embedding size for easy decoding')
parser.add_argument('--dropout', type=float, default=0.1, help='')
parser.add_argument("--learn_sigma", default=False, action="store_true", help="Whether to learning variance")
parser.add_argument('--logits_mode', type=int, default=1, help='final logits mode of Diffusion model')
parser.add_argument('--vocab_size', type=int, default=30522, help='vocab size')
parser.add_argument('--config_name', type=str, default='bert-base-uncased', help='')
parser.add_argument('--token_emb_type', type=str, default='random', help='token embedding type')
parser.add_argument("--init_pretrained", default=False, action="store_true", help="Whether to using pretrain BERT encoder")
parser.add_argument("--fix_encoder", default=False, action="store_true",
help="Whether to training encoder")
# load diffusion
parser.add_argument('--diffusion_steps', type=int, default=2000, help='Diffusion model maximum T')
    parser.add_argument('--use_kl', default=False, action="store_true", help="Whether to use the KL loss in the diffusion loss")
    parser.add_argument('--training_mode', type=str, default='e2e', help='e2e simple loss, e2e loss, or s2s loss')
    parser.add_argument('--noise_schedule', type=str, default='sqrt', help='Noise schedule of the Gaussian diffusion process')
    parser.add_argument('--predict_xstart', default=False, action="store_true", help="Model prediction target: if set, predict x_start; otherwise predict epsilon")
parser.add_argument("--sigma_small", default=False, action="store_true", help="about learning variance")
parser.add_argument("--rescale_learned_sigmas", default=True, action="store_false", help="about learning variance")
parser.add_argument("--rescale_timesteps", default=True, action="store_false", help="about time rescale")
# sample t
    parser.add_argument('--schedule_sampler', type=str, default='uniform', help='how to sample t per batch: "uniform" samples timesteps uniformly, "loss-second-moment" samples according to the loss')
# data args
parser.add_argument('--data_path', type=str, default='',help='data path')
parser.add_argument('--data_name', type=str, default='', help='data name')
    # for pretrain
    parser.add_argument('--pre_max_len', type=int, default=512, help='pretrain max len')
    parser.add_argument('--mask_pro', type=float, default=0.3, help='mask probability')
# training args
parser.add_argument('--train_type', type=str, default='LM_Diffusion', help='LM_Diffusion or S2S_Diffusion')
parser.add_argument('--lr_anneal_steps', type=int, default=200000, help='total step')
parser.add_argument('--batch_size', type=int, default=64, help='')
parser.add_argument('--lr', type=float, default=1e-04, help='')
parser.add_argument('--warmup_steps', type=int, default=20000, help='')
    parser.add_argument('--ema_rate', type=str, default='0.9999', help='EMA rate used to stabilize training')
parser.add_argument('--resume_checkpoint', type=str, default=None, help='')
parser.add_argument('--eval_interval', type=int, default=2000, help='')
parser.add_argument('--log_interval', type=int, default=100, help='')
parser.add_argument('--save_interval', type=int, default=50000, help='')
parser.add_argument('--gradient_accumulation_steps', type=int, default=8, help='')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='')
parser.add_argument('--gradient_clipping', type=float, default=-1., help='')
parser.add_argument("--use_fp16", default=False, action="store_true", help="about learning variance")
parser.add_argument('--fp16_scale_growth', type=float, default=1e-3, help='')
# seed
parser.add_argument('--seed', type=int, default=101, help='')
    # multi-gpu
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For remote debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For remote debugging.")
args = parser.parse_args()
return args
def setup_env(args):
if args.local_rank == -1:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
    else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend="nccl")
args.n_gpu = 1
args.device = device
# store args
if args.local_rank != -1:
args.world_size = torch.distributed.get_world_size()
args.rank = dist.get_rank()
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
logger.info('Reading saved model from %s', model_file)
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
logger.info('model_state_dict keys %s', state_dict.keys())
return CheckpointState(**state_dict)
def main():
# args setting
args = get_arguments()
# out dir set
if dist.get_rank() == 0:
if not os.path.exists(args.checkpoint_path):
os.makedirs(args.checkpoint_path)
# dist.barrier()
logger.log(f'saving the hyperparameters to {args.checkpoint_path}/training_args.json')
with open(f'{args.checkpoint_path}/training_args.json', 'w') as f:
json.dump(args.__dict__, f, indent=2)
# seed setting
# set_seed(args.seed)
# dpp setting
setup_env(args)
# logger setting
log_path = os.path.join(args.checkpoint_path, 'log.txt')
logger.configure(dir=log_path)
model, diffusion = create_model_and_diffusion(
args
)
# load pretrain model to continue train
if args.pretrain_model_path is not None:
print("load model ckpt at :", args.pretrain_model_path)
saved_state = load_states_from_checkpoint(args.pretrain_model_path)
model.load_state_dict(saved_state.model_dict, strict=False)
model.to(args.device)
pytorch_total_params = sum(p.numel() for p in model.parameters())
logger.log(f'the parameter count is {pytorch_total_params}')
'''
time step schedule sampler
'''
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
logger.log("load tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
'''
for s2s
'''
dataname_list = ['book1','book2','book3','book4','book5',
'wiki1', 'wiki2', 'wiki3', 'wiki4', 'wiki5',
'stories1','stories2','stories3','stories4','stories5',
'openweb1','openweb2','openweb3','openweb4','openweb5',
'realnews1', 'realnews2', 'realnews3', 'realnews4', 'realnews5',
'realnews6', 'realnews7', 'realnews8', 'realnews9','realnews10']
# roll data list
start_index = 0
'''
training
'''
logger.log("pretraining Diffusion LM model using ")
PretrainLoop(
# training type
train_type=args.train_type,
# Training Core
model=model,
diffusion=diffusion,
data=dataname_list,
eval_data=None,
schedule_sampler=schedule_sampler,
checkpoint_path=args.checkpoint_path,
# Training Parameters
batch_size=args.batch_size,
lr=args.lr,
ema_rate=args.ema_rate,
weight_decay=args.weight_decay,
lr_anneal_steps=args.lr_anneal_steps,
gradient_clipping=args.gradient_clipping,
# fp16
use_fp16=args.use_fp16,
fp16_scale_growth=args.fp16_scale_growth,
# Training Log
resume_checkpoint=args.resume_checkpoint,
eval_interval=args.eval_interval,
log_interval=args.log_interval,
save_interval=args.save_interval,
gradient_accumulation_steps=args.gradient_accumulation_steps,
# device
device=args.device,
args=args,
tokenizer=tokenizer,
).run_loop()
if __name__ == "__main__":
main()

160
GENIE/README.md Normal file
View File

@@ -0,0 +1,160 @@
# GENIE
This repo provides the code and models for [Text Generation with Diffusion Language Models: A Pre-training Approach with Continuous Paragraph Denoise](https://arxiv.org/abs/2212.11685).
## 🚀 Overview
In this paper, we introduce a novel d**I**ffusion language mod**E**l pre-training framework for text **GEN**eration, which we call **GENIE**. GENIE is a large-scale pretrained diffusion language model that consists of an encoder and a diffusion-based decoder, which can generate text by gradually transforming a random noise sequence into a coherent text sequence.
<div align=center><img src="image/GENIE.png" width = "600" height = 300/></div>
To pre-train GENIE on a large-scale language corpus, we design a novel pre-training method called *continuous paragraph denoise* (CPD), which encourages the diffusion-decoder to reconstruct a clean text paragraph from a corrupted version, while preserving the semantic and syntactic coherence.
You can find more details in the [paper](https://arxiv.org/abs/2212.11685).
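As a rough illustration, CPD corrupts a paragraph by removing one contiguous span and replacing it with a single mask token; the removed span becomes the clean target that the diffusion decoder learns to reconstruct. Below is a minimal, simplified sketch of one such span-corruption step (modeled on the `Pre_dataset_type2` dataset class included in this commit; the helper name and toy values are illustrative, not the exact training code):

```python
import random

def cpd_corrupt(token_ids, mask_token_id, mask_ratio=0.3):
    """Simplified continuous paragraph denoise (CPD) corruption.

    One contiguous span covering roughly `mask_ratio` of the paragraph is cut
    out of the source and replaced by a single mask token; the removed span is
    returned as the clean target the diffusion decoder must reconstruct.
    """
    span_len = int(len(token_ids) * mask_ratio)
    start = random.randint(0, len(token_ids) - span_len - 1)
    target = token_ids[start:start + span_len]                        # span to denoise
    source = token_ids[:start] + [mask_token_id] + token_ids[start + span_len:]
    return source, target

# Toy example: corrupt a 20-token "paragraph" (103 is BERT's [MASK] id).
src, tgt = cpd_corrupt(list(range(1000, 1020)), mask_token_id=103)
print(len(src), len(tgt))  # -> 15 (14 kept tokens + 1 mask) and 6
```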
## ⚙️ Experiment Preparation
**Dependencies:**
- python>=3.6
- torch>=1.7.1
- datasets>=1.12.1
- transformers>=4.9.2 (Huggingface)
- pyrouge==0.1.3
**Downstream Task Dataset:**
The text generation benchmarks we use are well-known and widely used, including *XSum*, *CNN/DailyMail*, and *GigaWord*. You can find more detailed information and ways to obtain the datasets [here](https://microsoft.github.io/glge/).
**Model:**
We have released the checkpoint of GENIE after pre-training on a `160G` corpus (6-layer encoder and 6-layer decoder):
- **GENIE V1** [[link](https://drive.google.com/file/d/1-AZssEmgs0QdTp_w8-_4cPi0cV-Hot4N/view?usp=share_link)]
You can also quickly get the GENIE checkpoints fine-tuned on the *XSum*, *CNN/DailyMail*, and *GigaWord* here:
- GENIE XSum [[link](https://drive.google.com/file/d/1-3NJwuDbSV00TwYs5FqG5cHvCY10CW0h/view?usp=share_link)]
- GENIE CNN/DailyMail [[link](https://drive.google.com/file/d/1-6shROw2TLWPTMLQbESmhQzfI0Z3pAOm/view?usp=share_link)]
- GENIE GigaWord [[link](https://drive.google.com/file/d/1-7PoPTX0w4Q_Sh4qrxB1WQId1tBCydY-/view?usp=share_link)]
We will continue to update and optimize this repo in the future.
## 💡 Pre-training
In the pre-training process, we use pre-training data consisting of `160Gb` of news, books, stories, and web text. We trained **GENIE V1** on `8 * 40G A100` GPUs for `50 days`. If you are interested in our pre-training process, please refer to `Genie_Pretrain.py`. Here we provide the pre-training running script for reference:
```shell
OUT_DIR="/Your/output/path"
DATA_PATH="/Your/pretrain/data/path"
python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9489 \
./GENIE_main/Genie_Pretrain.py \
--checkpoint_path=$OUT_DIR \
--model_channels 128 --in_channel 128 --out_channel 128 --vocab_size 30522 \
--config_name="bert-base-uncased" --token_emb_type="random" --model_arch="s2s_CAT" \
--diffusion_steps 2000 --predict_xstart --noise_schedule="sqrt" --training_mode="s2s" \
--schedule_sampler="uniform" --pre_max_len 512 --mask_pro 0.3 --seed 2023 \
--data_path=$DATA_PATH \
--batch_size 64 --lr 1e-04 --warmup_steps 300000 --train_type="S2S_Diffusion" \
--eval_interval 2000 --log_interval 2000 --save_interval 20000
```
Pre-training a diffusion model requires careful hyperparameter tuning and a reasonable training configuration, especially the vocabulary size and the input/output channel dimensions, which are worth continued exploration.
## ⚽ Fine-tuning
In this section, we will use *XSum* dataset as an example to demonstrate the process of GENIE fine-tuning on downstream tasks. The running script for fine-tuning is as follows:
```shell
OUT_DIR="/Your/output/path"
DATA_PATH="/Your/data/path"
DATA_NAME="xsum_data"
PRETRAIN_CKPT_PATH="/Your/pretrain_ckpt/path"
python -u -m torch.distributed.launch --nproc_per_node=4 --master_port=9421 \
./GENIE_main/Genie_Finetune.py \
--checkpoint_path=$OUT_DIR \
--model_channels 128 --in_channel 128 --out_channel 128 --vocab_size 30522 \
--config_name="bert-base-uncased" --token_emb_type="random" --model_arch="s2s_CAT" \
--diffusion_steps 2000 --predict_xstart --noise_schedule="sqrt" --training_mode="s2s" \
--schedule_sampler="loss-second-moment" --tgt_max_len 64 --src_max_len 512 --data_name=$DATA_NAME \
--data_path=$DATA_PATH \
--lr_anneal_steps 120000 --batch_size 64 --lr 5e-05 --warmup_steps 7200 --train_type="S2S_Diffusion" \
--eval_interval 200 --log_interval 200 --save_interval 20000 \
--pretrain_model_path=$PRETRAIN_CKPT_PATH
```
Important parameter setting:
- `--checkpoint_path`: Location of model checkpoints and log file output after fine-tuning.
- `--data_path`: Root directory of the downstream task datasets.
- `--data_name`: Name of the downstream task dataset. Make sure your data lives in the directory formed by `data_path + data_name`, which must contain the data files `train.src`, `train.tgt`, `dev.src`, `dev.tgt`, `test.src`, `test.tgt`.
- `--pretrain_model_path`: GENIE checkpoint path after pre-training.
If you need to switch the fine-tuning task, just organize the data into the same standard format (see the layout sketch below) and change `DATA_NAME` accordingly, e.g. to `cnndm_data` or `gigaword_data` for *CNN/DailyMail* and *GigaWord*.
If you need to train from scratch (w/o pre-train), just remove the parameter `--pretrain_model_path`.
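For reference, a minimal sketch of the expected dataset layout, assuming `DATA_PATH=/Your/data/path` and `DATA_NAME=xsum_data`; the `org_data` subfolder matches the paths read by `load_s2s_data` in this commit:

```text
/Your/data/path/
└── xsum_data/
    └── org_data/
        ├── train.src   # one source document per line
        ├── train.tgt   # one target text per line
        ├── dev.src
        ├── dev.tgt
        ├── test.src
        └── test.tgt
```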
## 💬 Generate
In this section, we will show how to batch generate text from trained GENIE. We need to sample the Gaussian noise and iteratively denoise it with GENIE to restore the text. The running script for generating is as follows:
```shell
OUT_DIR="/Your/output/path"
MODEL_DIR="/Your/model/ckpt/path"
DATA_PATH="/Your/data/path"
DATA_NAME="xsum_data"
python -u -m torch.distributed.launch --nproc_per_node=8 --master_port=9498 \
./GENIE_main/Genie_Generate.py \
--generate_path=$OUT_DIR \
--eval_model_path=$MODEL_DIR \
--data_path=$DATA_PATH \
--model_channels 128 --in_channel 128 --out_channel 128 --vocab_size 30522 \
--config_name="bert-base-uncased" --token_emb_type="random" \
--diffusion_steps 2000 --predict_xstart --noise_schedule="sqrt" \
--num_samples 5 --model_arch="s2s_CAT" --data_name=$DATA_NAME \
--training_mode="s2s" --tgt_max_len 64 --src_max_len 512 --batch_size=200 \
--interval_step 1 --seed 2023
```
Important parameter setting:
- `--generate_path`: Output location of generated text.
- `--eval_model_path`: Model checkpoint path needed for generation after training.
- `--data_path`: Root directory of the downstream task datasets.
- `--data_name`: Name of the downstream task dataset. Make sure your data lives in the directory formed by `data_path + data_name`, which must contain the data files `train.src`, `train.tgt`, `dev.src`, `dev.tgt`, `test.src`, `test.tgt`.
- `--num_samples`: The number of Gaussian noise samples drawn per input (i.e., the number of texts generated per input).
- `--interval_step`: Interval between denoising steps; defaults to 1.
You can adjust `--batch_size` and the number of parallel GPUs (`--nproc_per_node`) based on the capacity of the devices you are using. Each generated text file is named `rank[gpu_id]_gen_seed_[seed]_num[num_samples]_epoch[sample epoch].txt`.
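For instance, with `--num_samples 5`, rank 0 of the run above would write files like the following (in the current `Genie_Generate.py` the seed tag in the file name is fixed to `101`):

```text
rank0_gen_seed_101_num5_epoch1.txt
rank0_gen_seed_101_num5_epoch2.txt
...
rank0_gen_seed_101_num5_epoch5.txt
```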
Finally, we need to merge the generated text by running the following script:
```shell
OUT_DIR="/Your/output/path"
DATA_PATH="/Your/data/path"
DATA_NAME="xsum_data"
python ./GENIE_main/integration/eval_split.py \
--generate_path=$OUT_DIR \
--data_path=$DATA_PATH \
--num_samples 5 --data_name=$DATA_NAME --n_gpu 8 --seed 2023
```
Note that the parameter settings above need to be consistent with those used for generation; the best results will be saved under `--generate_path`. If you want to reproduce the results in the paper, please use the official GLGE evaluation method [here](https://github.com/microsoft/ProphetNet/tree/master/GLGE_baselines).
## 📜 Citation
Please cite our paper if you use [GENIE](https://arxiv.org/abs/2212.11685) in your work:
```bibtex
@article{lin2022genie,
  title   = {Text Generation with Diffusion Language Models: A Pre-training Approach with Continuous Paragraph Denoise},
  author  = {Zhenghao Lin and Yeyun Gong and Yelong Shen and Tong Wu and Zhihao Fan and Chen Lin and Nan Duan and Weizhu Chen},
  journal = {arXiv preprint arXiv:2212.11685},
  year    = {2022}
}
```

View File

View File

@@ -0,0 +1,224 @@
from tqdm import tqdm
import os
from transformers import AutoTokenizer
from torch.utils.data.dataset import Dataset
import torch
import jsonlines
import numpy as np
import random
def load_loop_pretrain_data(args, padding_mode, tokenizer, data_name = None):
print("***** load " + data_name + " train src dataset*****")
path = os.path.join(args.data_path, data_name + '.npy')
input_id_list = np.load(path, allow_pickle=True)
# filter
input_id_list = np.array([input_id for input_id in input_id_list if np.count_nonzero(input_id) >= 30])
if padding_mode == 'max_len':
dataset = Pre_dataset(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
elif padding_mode == 'conti_tgt':
print("using new pretrain method...")
dataset = Pre_dataset_type2(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
elif padding_mode == 'block':
print("padding block is under realization")
pass
else:
return NotImplementedError
print("example of src id lists: ", dataset[50][0])
print("example of tgt id lists: ", dataset[50][1])
print("total query dataset len :", len(dataset))
return dataset
def load_pretrain_data(args, padding_mode, tokenizer, data_name = None):
questions = []
print("***** load " + data_name + " train src dataset*****")
input_id_list = None
if data_name == "book" or data_name == "openweb" or data_name == "wiki" or data_name == "stories":
for i in range(5):
path = os.path.join(args.data_path, args.data_name + str(i+1) + '.npy')
input_id_list_pre = np.load(path, allow_pickle=True)
if i == 0:
input_id_list = input_id_list_pre
else:
input_id_list = np.concatenate((input_id_list, input_id_list_pre), axis=0)
# with open(path, "r", encoding="utf-8") as ifile:
# for line in tqdm(ifile):
# line = line.strip()
# text = line
# tgt.append(text)
elif data_name == 'realnews':
# for i in range(10):
# path = os.path.join(args.data_path, args.data_name + str(i+1) + '.txt')
# with open(path, "r", encoding="utf-8") as ifile:
# for line in tqdm(ifile):
# line = line.strip()
# text = line
# tgt.append(text)
for i in range(10):
path = os.path.join(args.data_path, args.data_name + str(i+1) + '.npy')
input_id_list_pre = np.load(path, allow_pickle=True)
if i == 0:
input_id_list = input_id_list_pre
else:
input_id_list = np.concatenate((input_id_list, input_id_list_pre), axis=0)
else:
        raise NotImplementedError
# filter
input_id_list = np.array([input_id for input_id in input_id_list if np.count_nonzero(input_id) >= 256])
# print("example of src text: ", src[50])
print("example of input id: ", input_id_list[50])
if padding_mode == 'max_len':
dataset = Pre_dataset(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
elif padding_mode == 'conti_tgt':
print("using new pretrain method...")
dataset = Pre_dataset_type2(input_id_list, tokenizer, mask_pro=args.mask_pro, maxlength=args.pre_max_len)
elif padding_mode == 'block':
print("padding block is under realization")
pass
else:
return NotImplementedError
print("example of src id lists: ", dataset[50][0])
print("example of tgt id lists: ", dataset[50][1])
print("total query dataset len :", len(dataset))
return dataset
class Pre_dataset(Dataset):
def __init__(self, tgt_id, tokenizer, mask_pro=0.3, maxlength=512, span_size=8, mask_mode='random'):
self.tgt_id = tgt_id
self.tokenizer = tokenizer
self.maxlength = maxlength
self.mask_pro = mask_pro
self.span_size = span_size
self.mask_token_index = self.tokenizer.mask_token_id
self.pad_token_index = self.tokenizer.pad_token_id
self.all_special_token = self.tokenizer.all_special_ids
def __getitem__(self, index):
tgt_example = self.tgt_id[index]
# src_input_ids = tgt_example.tolist()
tgt_input_ids = (torch.from_numpy(tgt_example)).long()
src_input_ids = tgt_input_ids.clone()
id_len = torch.nonzero(src_input_ids).shape[0]
mask_span_num = int((id_len * self.mask_pro) // self.span_size) + 1
# print("mask_span_num:", mask_span_num)
mask_index = torch.randint(0, id_len, (mask_span_num,))
# print("mask_index:", mask_index)
mask_id_mask = torch.full(src_input_ids.shape, False, dtype=torch.bool)
retain_id_mask = torch.full(src_input_ids.shape, True, dtype=torch.bool)
mask_id_mask[mask_index] = True
del_index = mask_index.tolist()
for i in mask_index:
del_index.extend(list(range(i + 1, i + self.span_size)))
del_index = [i for i in del_index if i < id_len]
del_index = torch.from_numpy(np.array(list(set(del_index))))
# print("del_index", del_index)
retain_id_mask[del_index] = False
retain_id_mask = retain_id_mask | mask_id_mask
src_input_ids[mask_id_mask] = self.mask_token_index
src_input_ids = src_input_ids[retain_id_mask].tolist()
# print("src_input_ids1:", len(src_input_ids))
src_input_ids = src_input_ids + [self.pad_token_index] * (self.maxlength - len(src_input_ids))
# print("src_input_ids2:", len(src_input_ids))
src_input_ids = torch.from_numpy(np.array(src_input_ids)).long()
return src_input_ids.unsqueeze(0), tgt_input_ids.unsqueeze(0)
def __len__(self):
return len(self.tgt_id)
@classmethod
def get_collate_fn(cls):
def fn(features):
src_tensor = torch.cat([feature[0] for feature in features])
tgt_tensor = torch.cat([feature[1] for feature in features])
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
return fn
class Pre_dataset_type2(Dataset):
def __init__(self, tgt_id, tokenizer, mask_pro=0.3, maxlength=512, mask_mode='random'):
self.tgt_id = tgt_id
self.tokenizer = tokenizer
self.maxlength = maxlength
self.mask_pro = mask_pro
self.tgtmaxlength = int(maxlength * mask_pro) + 1
self.mask_token_index = self.tokenizer.mask_token_id
self.pad_token_index = self.tokenizer.pad_token_id
self.all_special_token = self.tokenizer.all_special_ids
def __getitem__(self, index):
tgt_example = self.tgt_id[index]
# src_input_ids = tgt_example.tolist()
tgt_input_ids = (torch.from_numpy(tgt_example)).long()
src_input_ids = tgt_input_ids.clone()
id_len = torch.nonzero(src_input_ids).shape[0]
# mask_span_num = int((id_len * self.mask_pro) // self.span_size) + 1
mask_span_len = int(id_len * self.mask_pro)
# print("mask_span_num:", mask_span_num)
mask_index = random.randint(0, id_len-mask_span_len-1)
tgt_input_ids = src_input_ids.tolist()[mask_index:mask_index+mask_span_len]
src_input_ids[mask_index] = self.mask_token_index
# print("mask_index:", mask_index)
# mask_span_len
# mask_id_mask = torch.full(src_input_ids.shape, False, dtype=torch.bool)
retain_id_mask = torch.full(src_input_ids.shape, True, dtype=torch.bool)
# mask_id_mask[mask_index] = True
# del_index = mask_index.tolist()
del_index = list(range(mask_index + 1, mask_index + mask_span_len))
del_index = torch.from_numpy(np.array(del_index))
retain_id_mask[del_index] = False
# src_input_ids[mask_id_mask] = self.mask_token_index
src_input_ids = src_input_ids[retain_id_mask].tolist()
# print("src_input_ids1:", len(src_input_ids))
src_input_ids = src_input_ids + [self.pad_token_index] * (self.maxlength - len(src_input_ids))
# print("src_input_ids2:", len(src_input_ids))
tgt_input_ids = tgt_input_ids + [self.pad_token_index] * (self.tgtmaxlength - len(tgt_input_ids))
src_input_ids = torch.from_numpy(np.array(src_input_ids)).long()
tgt_input_ids = torch.from_numpy(np.array(tgt_input_ids)).long()
return src_input_ids.unsqueeze(0), tgt_input_ids.unsqueeze(0)
def __len__(self):
return len(self.tgt_id)
@classmethod
def get_collate_fn(cls):
def fn(features):
src_tensor = torch.cat([feature[0] for feature in features])
tgt_tensor = torch.cat([feature[1] for feature in features])
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
return fn
if __name__ == "__main__":
pretrain_max_len = 512

View File

@@ -0,0 +1,408 @@
from tqdm import tqdm
import os
from transformers import AutoTokenizer
from torch.utils.data.dataset import Dataset
import torch
import jsonlines
def load_s2s_data(args, padding_mode, split, tokenizer):
questions = []
if split == 'train':
print("***** load " + args.data_name + " train src dataset*****")
src = []
train_src_path = os.path.join(args.data_path, args.data_name + "/org_data/train.src")
with open(train_src_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
src.append(text)
print("***** load " + args.data_name + " train tgt dataset*****")
tgt = []
train_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/train.tgt")
with open(train_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
tgt.append(text)
elif split == 'dev':
print("***** load " + args.data_name + " dev src dataset*****")
src = []
dev_src_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.src")
with open(dev_src_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
src.append(text)
print("***** load " + args.data_name + " dev tgt dataset*****")
tgt = []
dev_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.tgt")
with open(dev_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
tgt.append(text)
elif split == 'test':
print("***** load " + args.data_name + " test src dataset*****")
src = []
test_src_path = os.path.join(args.data_path, args.data_name + "/org_data/test.src")
with open(test_src_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
src.append(text)
print("***** load " + args.data_name + " dev tgt dataset*****")
tgt = []
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
tgt.append(text)
else:
print("no such split of data...")
exit(0)
print("example of src text: ", src[50])
print("example of tgt text: ", tgt[50])
if padding_mode == 'max_len':
if args.data_name == "squadqg_data":
dataset = QG_dataset_Diff(src, tgt, tokenizer, src_maxlength=args.src_max_len,
answer_maxlength=args.answer_max_len, tgt_maxlength=args.tgt_max_len)
else:
dataset = S2S_dataset(src, tgt, tokenizer, src_maxlength=args.src_max_len, tgt_maxlength=args.tgt_max_len)
elif padding_mode == 'block':
print("padding block is under realization")
pass
else:
return NotImplementedError
print("example of src id lists: ", dataset[50][0])
print("example of tgt id lists: ", dataset[50][1])
print("total query dataset len :", len(dataset))
return dataset
'''
for AR seq2seq training
'''
def load_s2s_data_AR(args, padding_mode, split, tokenizer):
if split == 'train':
print("***** load " + args.data_name + " train src dataset*****")
src = []
train_src_path = os.path.join(args.data_path, args.data_name + "/org_data/train.src")
with open(train_src_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
src.append(text)
print("***** load " + args.data_name + " train tgt dataset*****")
tgt = []
train_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/train.tgt")
with open(train_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
tgt.append(text)
elif split == 'dev':
print("***** load " + args.data_name + " dev src dataset*****")
src = []
dev_src_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.src")
with open(dev_src_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
src.append(text)
print("***** load " + args.data_name + " dev tgt dataset*****")
tgt = []
dev_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/dev.tgt")
with open(dev_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
tgt.append(text)
# src = src[:100]
# tgt = tgt[:100]
elif split == 'test':
print("***** load " + args.data_name + " test src dataset*****")
src = []
test_src_path = os.path.join(args.data_path, args.data_name + "/org_data/test.src")
with open(test_src_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
src.append(text)
print("***** load " + args.data_name + " dev tgt dataset*****")
tgt = []
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
tgt.append(text)
# src = src[:10]
# tgt = tgt[:10]
else:
print("no such split of data...")
exit(0)
print("example of src text: ", src[9].replace("<S_SEP>",'\n'))
print("example of tgt text: ", tgt[9].replace("<S_SEP>",'\n'))
if padding_mode == 'max_len':
dataset = S2S_AR_dataset(src, tgt, tokenizer, src_maxlength=args.src_max_len, tgt_maxlength=args.tgt_max_len)
elif padding_mode == 'block':
print("padding block is under realization")
pass
else:
return NotImplementedError
print("example of src id lists: ", dataset[9][0])
print("example of tgt id lists: ", dataset[9][1])
print("total query dataset len :", len(dataset))
return dataset
'''
load baseline data on SeqDiffusion
'''
def load_s2s_jsonl_data(args, padding_mode, split, tokenizer):
questions = []
if split == 'train':
print("***** load " + args.data_name + " train src and tgt dataset*****")
src = []
tgt = []
train_path = os.path.join(args.data_path, args.data_name + "/train.jsonl")
with jsonlines.open(train_path) as reader:
for obj in reader:
tgt.append(obj['trg'])
src.append(obj['src'])
elif split == 'dev':
print("***** load " + args.data_name + " dev src and tgt dataset*****")
src = []
tgt = []
dev_path = os.path.join(args.data_path, args.data_name + "/valid.jsonl")
with jsonlines.open(dev_path) as reader:
for obj in reader:
tgt.append(obj['trg'])
src.append(obj['src'])
elif split == 'test':
print("***** load " + args.data_name + " test src and tgt dataset*****")
src = []
tgt = []
test_path = os.path.join(args.data_path, args.data_name + "/test.jsonl")
with jsonlines.open(test_path) as reader:
for obj in reader:
tgt.append(obj['trg'])
src.append(obj['src'])
else:
print("no such split of data...")
exit(0)
print("example of src text: ", src[50])
print("example of tgt text: ", tgt[50])
if padding_mode == 'max_len':
dataset = S2S_dataset(src, tgt, tokenizer, src_maxlength=args.src_max_len, tgt_maxlength=args.tgt_max_len)
elif padding_mode == 'block':
print("padding block is under realization")
pass
else:
return NotImplementedError
print("example of src id lists: ", dataset[50][0])
print("example of tgt id lists: ", dataset[50][1])
print("total query dataset len :", len(dataset))
return dataset
class QG_dataset_Diff(Dataset):
def __init__(self, src, tgt, tokenizer, src_maxlength=144, answer_maxlength=20, tgt_maxlength=32):
self.src = src
self.tgt = tgt
self.tokenizer = tokenizer
self.src_maxlength = src_maxlength
self.tgt_maxlength = tgt_maxlength
self.ans_maxlength = answer_maxlength
def __getitem__(self, index):
src_example = self.src[index]
tgt_example = self.tgt[index]
answer = src_example.split('[SEP]')[0].strip()
passage = src_example.split('[SEP]')[1].strip()
src_input_ids = self.tokenizer.encode(passage, add_special_tokens=True,
max_length=self.src_maxlength, truncation=True,
padding='max_length',return_tensors='pt')
answer_ids = self.tokenizer.encode(answer, add_special_tokens=True,
max_length=self.ans_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
max_length=self.tgt_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
return src_input_ids, answer_ids, tgt_input_ids
def __len__(self):
return len(self.src)
@classmethod
def get_collate_fn(cls):
def fn(features):
src_tensor = torch.cat([feature[0] for feature in features])
ans_tensor = torch.cat([feature[1] for feature in features])
tgt_tensor = torch.cat([feature[2] for feature in features])
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
"answer_ids": ans_tensor, "answer_mask": (ans_tensor != 0).long(),
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
return fn
'''
s2s for AR model
'''
class S2S_AR_dataset(Dataset):
def __init__(self, src, tgt, tokenizer, src_maxlength=144, tgt_maxlength=32):
self.src = src
self.tgt = tgt
self.tokenizer = tokenizer
self.src_maxlength = src_maxlength
self.tgt_maxlength = tgt_maxlength
def __getitem__(self, index):
src_example = self.src[index]
tgt_example = self.tgt[index]
src_example = src_example.replace('<S_SEP>', '\n')
tgt_example = tgt_example.replace('<S_SEP>', '\n')
src_input_ids = self.tokenizer.encode(src_example, add_special_tokens=True,
max_length=self.src_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
max_length=self.tgt_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
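# pad positions (id 1, as in BART/RoBERTa-style tokenizers; BERT pads with 0) are replaced
# with -100 below so the Hugging Face cross-entropy loss ignores them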
tgt_input_ids[tgt_input_ids == 1] = -100
return src_input_ids, tgt_input_ids
def __len__(self):
return len(self.src)
@classmethod
def get_collate_fn(cls):
def fn(features):
src_tensor = torch.cat([feature[0] for feature in features])
tgt_tensor = torch.cat([feature[1] for feature in features])
# print("src shape:", src_tensor.shape)
# print("tgt shape:", tgt_tensor.shape)
return { "input_ids": src_tensor, "attention_mask": (src_tensor != 0).long(),
"labels": tgt_tensor}
return fn
class S2S_dataset(Dataset):
def __init__(self, src, tgt, tokenizer, src_maxlength=144, tgt_maxlength=32):
self.src = src
self.tgt = tgt
self.tokenizer = tokenizer
self.src_maxlength = src_maxlength
self.tgt_maxlength = tgt_maxlength
def __getitem__(self, index):
src_example = self.src[index]
tgt_example = self.tgt[index]
src_input_ids = self.tokenizer.encode(src_example, add_special_tokens=True,
max_length=self.src_maxlength, truncation=True,
padding='max_length',return_tensors='pt')
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
max_length=self.tgt_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
return src_input_ids, tgt_input_ids
def __len__(self):
return len(self.src)
@classmethod
def get_collate_fn(cls):
def fn(features):
src_tensor = torch.cat([feature[0] for feature in features])
tgt_tensor = torch.cat([feature[1] for feature in features])
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long() }
return fn
class S2S_imp_dataset(Dataset):
def __init__(self, src, tgt, ori_gen, tokenizer, src_maxlength=144, tgt_maxlength=32):
self.src = src
self.tgt = tgt
self.ori_gen = ori_gen
self.tokenizer = tokenizer
self.src_maxlength = src_maxlength
self.tgt_maxlength = tgt_maxlength
def __getitem__(self, index):
src_example = self.src[index]
tgt_example = self.tgt[index]
ori_gen_example = self.ori_gen[index]
src_input_ids = self.tokenizer.encode(src_example, add_special_tokens=True,
max_length=self.src_maxlength, truncation=True,
padding='max_length',return_tensors='pt')
tgt_input_ids = self.tokenizer.encode(tgt_example, add_special_tokens=True,
max_length=self.tgt_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
ori_gen_ids = self.tokenizer.encode(ori_gen_example, add_special_tokens=True,
max_length=self.tgt_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
return src_input_ids, tgt_input_ids, ori_gen_ids
def __len__(self):
return len(self.src)
@classmethod
def get_collate_fn(cls):
def fn(features):
src_tensor = torch.cat([feature[0] for feature in features])
tgt_tensor = torch.cat([feature[1] for feature in features])
ori_gen_tensor = torch.cat([feature[2] for feature in features])
return { "src_input_ids": src_tensor, "src_attention_mask": (src_tensor != 0).long(),
"tgt_input_ids": tgt_tensor, "tgt_attention_mask": (tgt_tensor != 0).long(),
"ori_gen_ids": ori_gen_tensor}
return fn
if __name__ == "__main__":
pass


@ -0,0 +1,343 @@
import os
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from util import logger
from train_util import dist_util
import numpy as np
def load_data_text(data_args, emb_model=None, padding_mode='max_len', split='train', tokenizer=None, return_hidden=False):
'''
init embedding model
'''
if split == 'train':
if data_args.token_emb_type == 'random' and emb_model is None:
print('loading initialized random embeddings. ')
emb_model = None
elif data_args.token_emb_type == 'pretrain' and emb_model is not None:
print('loading embeddings from pretraining embedding. ')
'''
load query text dataset (only query) according to split
'''
if data_args.model_arch == 'transformer':
dataset = get_query_corpus(data_args, padding_mode=padding_mode, split=split,
tokenizer=tokenizer)
elif data_args.model_arch == 's2s_CAT':
dataset = get_pandq_corpus(data_args, padding_mode=padding_mode, split=split,
tokenizer=tokenizer)
else:
raise NotImplementedError(f"unknown model_arch: {data_args.model_arch}")
'''
load embedding model
'''
# if data_args.token_emb_type == 'random' and emb_model is None:
# emb_model = torch.nn.Embedding(tokenizer.vocab_size, data_args.in_channel)
# print('initializing the random embeddings', emb_model)
# torch.nn.init.normal_(emb_model.weight)
# path_save = f'{data_args.checkpoint_path}/random_emb.torch'
# print(f'save the random encoder to {data_args.checkpoint_path}/random_emb.torch')
# torch.save(emb_model.state_dict(), path_save)
# elif data_args.token_emb_type == 'pretrain':
# # TODO: finish pretrain embedding setting
# print('initializing the pretrain embeddings', emb_model)
# else:
# return NotImplementedError
if return_hidden:
'''
encode input ids to input embedding
load temp dataloader to input in embedding model, getting hidden_state (word embedding)
'''
query_dataloader = DataLoader(dataset, batch_size=1024, drop_last=False,
num_workers=20, collate_fn=Question_dataset.get_collate_fn())
emb_model.to(dist_util.dev())
hidden_state_set = []
with torch.no_grad():
for k, (ids, text_ids, text_mask) in enumerate(tqdm(query_dataloader)):
hidden_state = emb_model(text_ids.long().to(dist_util.dev()))
# hidden_state :(batch_size, seq_len, hidden_size)
hidden_state = hidden_state.detach().cpu()
hidden_state_set.append(hidden_state)
hidden_state_set = torch.cat(hidden_state_set, dim=0)
'''
load query embedding dataloader
'''
query_emb_dataset = Text_Hidden_dataset(dataset, hidden_state_set)
# query_dataloader = DataLoader(query_emb_dataset, batch_size=data_args.batch_size, drop_last=False,
# num_workers=20, collate_fn=Text_Hidden_dataset.get_collate_fn())
emb_model.cpu()
return query_emb_dataset, emb_model
else:
return dataset, emb_model
# for input_ids in dataset['word_ids']:
# if data_args.experiment.startswith('random'):
# hidden_state = model(torch.tensor(input_ids))
# elif data_args.experiment == 'gpt2_pre_compress':
# input_ids2 = torch.tensor(input_ids).to(model.device)
# input_embs = model.transformer.wte(input_ids2) # input_embs
# hidden_state = model.down_proj(input_embs)
# hidden_state = hidden_state * data_args.emb_scale_factor
# elif data_args.experiment == 'glove':
# hidden_state = model(torch.tensor(input_ids))
# result_train_lst.append({'input_ids': input_ids, 'hidden_states': hidden_state.cpu().tolist()})
def get_query_corpus(args, padding_mode, split, tokenizer):
questions = []
if split == 'train':
print("***** load train query dataset*****")
train_question_path = os.path.join(args.data_path, "train.query.txt")
with open(train_question_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
id, text = line.split('\t')
questions.append([int(id), text])
elif split == 'dev':
dev_question_path = os.path.join(args.data_path, "dev.query.txt")
with open(dev_question_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
id, text = line.split('\t')
questions.append([int(id), text])
else:
print("no such split of data...")
exit(0)
print("example of questions text: ", questions[50])
if padding_mode == 'max_len':
dataset = Question_dataset(questions, tokenizer, maxlength=args.text_max_len)
elif padding_mode == 'block':
raise NotImplementedError("block padding is not implemented yet")
else:
raise NotImplementedError(f"unknown padding mode: {padding_mode}")
print("example of questions id lists: ", dataset[50])
print("total query dataset len :", len(dataset))
return dataset
def get_pandq_corpus(args, padding_mode, split, tokenizer):
questions = []
if split == 'train':
print("***** load train query dataset*****")
train_question_path = os.path.join(args.data_path, "train.query.txt")
with open(train_question_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
id, text = line.split('\t')
questions.append([int(id), text])
print("***** load query to passage qrel data*****")
qrel_path = os.path.join(args.data_path, "qrels.train.tsv")
qids_to_relevant_passageids = load_train_reference_from_stream(qrel_path)
elif split == 'dev':
dev_question_path = os.path.join(args.data_path, "dev.query.txt")
with open(dev_question_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
id, text = line.split('\t')
questions.append([int(id), text])
print("***** load query to passage qrel data*****")
qrel_path = os.path.join(args.data_path, "qrels.dev.tsv")
qids_to_relevant_passageids = load_dev_reference_from_stream(qrel_path)
else:
print("no such split of data...")
exit(0)
print("***** load passages dataset*****")
passage_title_path = os.path.join(args.data_path, "para.title.txt")
passage_ctx_path = os.path.join(args.data_path, "para.txt")
passage_title = load_id_text(passage_title_path)
passages = {}
with open(passage_ctx_path) as inp:
for line in tqdm(inp):
line = line.strip()
id, text = line.split('\t')
passages[int(id)] = (text, passage_title.get(id, '-'))
print("example of questions text: ", questions[50])
qid = questions[50][0]
pid = qids_to_relevant_passageids[qid][0]
print("example of passage text --title: ", passages[pid][1])
print("example of passage text --passage: ", passages[pid][0])
if padding_mode == 'max_len':
dataset = PandQ_dataset(questions, qids_to_relevant_passageids, passages,
tokenizer, q_maxlength=args.text_max_len, p_maxlength=args.pas_max_len)
elif padding_mode == 'block':
raise NotImplementedError("block padding is not implemented yet")
else:
raise NotImplementedError(f"unknown padding mode: {padding_mode}")
print("example of questions id lists: ", dataset[50][0])
print("example of passage id lists: ", dataset[50][1])
print("total query dataset len :", len(dataset))
return dataset
class Text_Hidden_dataset(Dataset):
def __init__(self, query_dataset, hidden_state_set):
self.query_dataset = query_dataset
self.hidden_state_set = hidden_state_set
def __getitem__(self, index):
example = self.query_dataset[index]
hidden_state = self.hidden_state_set[index]
return example[0], example[1], hidden_state
'''
query_ids : len(np list) = batch_size query
input_ids : (batch_size, seq_len)
attention_mask : (batch_size, seq_len)
hidden_state : (batch_size, seq_len, embedding)
'''
@classmethod
def get_collate_fn(cls):
def fn(features):
id_list = [feature[0] for feature in features]
q_tensor = torch.cat([feature[1] for feature in features])
hidden_state = torch.cat([feature[2] for feature in features])
return {"query_ids":np.array(id_list), "input_ids":q_tensor,
"attention_mask":(q_tensor != 0).long(), "hidden_state":hidden_state}
return fn
class PandQ_dataset(Dataset):
def __init__(self, questions, qids_to_relevant_passageids, passages,
tokenizer, q_maxlength=32, p_maxlength=144):
self.questions = questions
self.tokenizer = tokenizer
self.q_maxlength = q_maxlength
self.p_maxlength = p_maxlength
self.qids_to_relevant_passageids = qids_to_relevant_passageids
self.passages = passages
def __getitem__(self, index):
q_example = self.questions[index]
query_id = q_example[0]
q_input_ids = self.tokenizer.encode(q_example[1], add_special_tokens=True,
max_length=self.q_maxlength, truncation=True,
padding='max_length',return_tensors='pt')
rel_passage_id = self.qids_to_relevant_passageids[query_id][0]
passage_example = self.passages[rel_passage_id]
text = passage_example[0]
title = passage_example[1]
p_input_ids = self.tokenizer.encode(title, text_pair=text, add_special_tokens=True,
max_length=self.p_maxlength, truncation=True,
padding='max_length', return_tensors='pt')
return q_input_ids, p_input_ids
def __len__(self):
return len(self.questions)
@classmethod
def get_collate_fn(cls):
def fn(features):
q_tensor = torch.cat([feature[0] for feature in features])
p_tensor = torch.cat([feature[1] for feature in features])
return { "q_input_ids": q_tensor, "q_attention_mask": (q_tensor != 0).long(),
"p_input_ids": p_tensor, "p_attention_mask": (p_tensor != 0).long() }
return fn
class Question_dataset(Dataset):
def __init__(self, questions, tokenizer,maxlength=32):
self.questions = questions
self.tokenizer = tokenizer
self.maxlength = maxlength
def __getitem__(self, index):
example = self.questions[index]
input_ids = self.tokenizer.encode(example[1], add_special_tokens=True,
max_length=self.maxlength, truncation=True,
padding='max_length',return_tensors='pt')
return example[0], input_ids
def __len__(self):
return len(self.questions)
@classmethod
def get_collate_fn(cls):
def fn(features):
id_list = [feature[0] for feature in features]
q_tensor = torch.cat([feature[1] for feature in features])
return np.array(id_list), q_tensor, (q_tensor != 0).long()
return fn
def load_dev_reference_from_stream(path_to_reference):
"""Load Reference reference relevant passages
Args:f (stream): stream to load.
Returns:qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
"""
qids_to_relevant_passageids = {}
with open(path_to_reference, 'r') as f:
for l in f:
try:
l = l.strip().split('\t')
qid = int(l[0])
if qid in qids_to_relevant_passageids:
pass
else:
qids_to_relevant_passageids[qid] = []
qids_to_relevant_passageids[qid].append(int(l[1]))
except:
raise IOError('\"%s\" is not valid format' % l)
return qids_to_relevant_passageids
def csv_reader(fd, delimiter='\t', trainer_id=0, trainer_num=1):
def gen():
for i, line in tqdm(enumerate(fd)):
if i % trainer_num == trainer_id:
slots = line.rstrip('\n').split(delimiter)
if len(slots) == 1:
yield slots,
else:
yield slots
return gen()
def load_train_reference_from_stream(input_file, trainer_id=0, trainer_num=1):
"""Reads a tab separated value file."""
with open(input_file, 'r', encoding='utf8') as f:
reader = csv_reader(f, trainer_id=trainer_id, trainer_num=trainer_num)
#headers = 'query_id\tpos_id\tneg_id'.split('\t')
#Example = namedtuple('Example', headers)
qrel = {}
for [topicid, _, docid, rel] in reader:
topicid = int(topicid)
assert rel == "1"
if topicid in qrel:
qrel[topicid].append(int(docid))
else:
qrel[topicid] = [int(docid)]
return qrel
def load_id_text(file_name):
"""load tsv files"""
id_text = {}
with open(file_name) as inp:
for line in tqdm(inp):
line = line.strip()
id, text = line.split('\t')
id_text[id] = text
return id_text
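# On-disk layout assumed by the loaders above (inferred from the parsing code, not documented
# in the original):
#   train.query.txt / dev.query.txt : "<qid>\t<query text>" per line
#   qrels.train.tsv                 : "<qid>\t<_>\t<pid>\t1" per line (4 columns, rel must be "1")
#   qrels.dev.tsv                   : "<qid>\t<pid>" per line
#   para.title.txt / para.txt       : "<pid>\t<title or passage text>" per line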

The diff for one file is not shown because of its large size.

@ -0,0 +1,157 @@
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
print("using uniform time schedule sampler")
return UniformSampler(diffusion)
elif name == "loss-second-moment":
print("using loss-second-moment time schedule sampler")
return LossSecondMomentResampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
# print(w)
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, diffusion):
self.diffusion = diffusion
self._weights = np.ones([diffusion.num_timesteps])
def weights(self):
return self._weights
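# Usage sketch (an illustration, not part of the original file): drawing importance-sampled
# timesteps for one batch. The returned weights satisfy E_p[w * loss_t] = mean_t[loss_t], so
# multiplying each per-example loss by its weight keeps the training objective unbiased.
def _example_sample_timesteps(diffusion, batch_size=16, device="cpu"):
    sampler = create_named_schedule_sampler("uniform", diffusion)
    t, weights = sampler.sample(batch_size, th.device(device))
    # t: (batch_size,) long tensor of timestep indices
    # weights: (batch_size,) float tensor; all ones for the uniform sampler
    return t, weights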
class LossAwareSampler(ScheduleSampler):
def update_with_local_losses(self, local_ts, local_losses):
"""
Update the reweighting using losses from a model.
Call this method from each rank with a batch of timesteps and the
corresponding losses for each of those timesteps.
This method will perform synchronization to make sure all of the ranks
maintain the exact same reweighting.
:param local_ts: an integer Tensor of timesteps.
:param local_losses: a 1D Tensor of losses.
"""
batch_sizes = [
th.tensor([0], dtype=th.int32, device=local_ts.device)
for _ in range(dist.get_world_size())
]
dist.all_gather(
batch_sizes,
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
)
# Pad all_gather batches to be the maximum batch size.
batch_sizes = [x.item() for x in batch_sizes]
max_bs = max(batch_sizes)
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
dist.all_gather(timestep_batches, local_ts)
dist.all_gather(loss_batches, local_losses)
timesteps = [
x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
]
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
self.update_with_all_losses(timesteps, losses)
@abstractmethod
def update_with_all_losses(self, ts, losses):
"""
Update the reweighting using losses from a model.
Sub-classes should override this method to update the reweighting
using losses from the model.
This method directly updates the reweighting without synchronizing
between workers. It is called by update_with_local_losses from all
ranks with identical arguments. Thus, it should have deterministic
behavior to maintain state across workers.
:param ts: a list of int timesteps.
:param losses: a list of float losses, one per timestep.
"""
class LossSecondMomentResampler(LossAwareSampler):
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
self.diffusion = diffusion
self.history_per_term = history_per_term
self.uniform_prob = uniform_prob
self._loss_history = np.zeros(
[diffusion.num_timesteps, history_per_term], dtype=np.float64
)
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
def weights(self):
if not self._warmed_up():
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
weights /= np.sum(weights)
weights *= 1 - self.uniform_prob
weights += self.uniform_prob / len(weights)
return weights
def update_with_all_losses(self, ts, losses):
for t, loss in zip(ts, losses):
if self._loss_counts[t] == self.history_per_term:
# Shift out the oldest loss term.
self._loss_history[t, :-1] = self._loss_history[t, 1:]
self._loss_history[t, -1] = loss
else:
self._loss_history[t, self._loss_counts[t]] = loss
self._loss_counts[t] += 1
def _warmed_up(self):
return (self._loss_counts == self.history_per_term).all()
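# Reweighting summary for LossSecondMomentResampler (derived from the code above): once every
# timestep has history_per_term recorded losses, timestep t is drawn with probability
# proportional to sqrt(mean(loss_t^2)), mixed with a uniform floor of uniform_prob, so training
# focuses on timesteps whose loss is currently large without completely starving the others.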

Просмотреть файл

@ -0,0 +1,133 @@
import numpy as np
import torch as th
from diffusion_util.gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim") :])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
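# Worked examples of the spacing above (a small sketch, not part of the original file):
def _example_spacing():
    # "ddim50" over 2000 steps finds the integer stride 40, i.e. {0, 40, 80, ..., 1960}
    ddim_steps = space_timesteps(2000, "ddim50")
    assert len(ddim_steps) == 50
    # the docstring example: 300 steps split into three 100-step portions of 10, 15 and 20 steps
    sectioned = space_timesteps(300, [10, 15, 20])
    assert len(sectioned) == 10 + 15 + 20
    return ddim_steps, sectioned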
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
# print(kwargs.keys())
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
# recompute the betas using only the timesteps kept in the sampling schedule (original note: purpose unclear)
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
# print('called p_mean_var')
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
# print('called training_losses')
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
# print(ts)
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
# print(new_ts)
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
# temp = self.model(x, new_ts, **kwargs)
# print(temp.shape)
# return temp
# print(new_ts)
return self.model(x, new_ts, **kwargs)
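# Remapping sketch (an illustration of the wrapper above): with use_timesteps = {0, 40, ..., 1960},
# timestep_map becomes [0, 40, ..., 1960], so a sampled index ts = 3 is fed to the model as
# new_ts = timestep_map[3] = 120, optionally rescaled by 1000 / original_num_steps.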

Binary data
GENIE/image/GENIE.png Normal file
(binary image not shown; size: 187 KiB)

@ -0,0 +1,269 @@
import os
import argparse
from tqdm import tqdm
import string
import rouge
def get_arguments():
parser = argparse.ArgumentParser()
# out path
parser.add_argument('--generate_path', type=str, default='', help='output path')
parser.add_argument('--num_samples', type=int, default=50, help='sample query')
# data args
parser.add_argument('--data_path', type=str, default='', help='data path')
parser.add_argument('--data_name', type=str, default='', help='data name')
parser.add_argument('--batch_size', type=int, default=64, help='')
# seed
parser.add_argument('--seed', type=int, default=101, help='')
parser.add_argument('--n_gpu', type=int, default=4, help='')
args = parser.parse_args()
return args
_tok_dict = {"(": "-lrb-", ")": "-rrb-",
"[": "-lsb-", "]": "-rsb-",
"{": "-lcb-", "}": "-rcb-",
"[UNK]": "UNK", '&': '&amp;', '<': '&lt;', '>': '&gt;'}
def _is_digit(w):
for ch in w:
if not(ch.isdigit() or ch == ','):
return False
return True
def fix_tokenization(text):
input_tokens = text.split()
output_tokens = []
has_left_quote = False
has_left_single_quote = False
i = 0
prev_dash = False
while i < len(input_tokens):
tok = input_tokens[i]
flag_prev_dash = False
if tok in _tok_dict.keys():
output_tokens.append(_tok_dict[tok])
i += 1
elif tok == "\"":
if has_left_quote:
output_tokens.append("''")
else:
output_tokens.append("``")
has_left_quote = not has_left_quote
i += 1
elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t":
output_tokens[-1] = output_tokens[-1][:-1]
output_tokens.append("n't")
i += 2
elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"):
output_tokens.append("'"+input_tokens[i + 1])
i += 2
elif tok == "'":
if has_left_single_quote:
output_tokens.append("'")
else:
output_tokens.append("`")
has_left_single_quote = not has_left_single_quote
i += 1
elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".":
output_tokens.append("...")
i += 3
elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]):
# $ 3 , 000 -> $ 3,000
output_tokens[-1] += ','+input_tokens[i + 1]
i += 2
elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit():
# 3 . 03 -> 3.03
output_tokens[-1] += '.'+input_tokens[i + 1]
i += 2
elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.':
# U . N . -> U.N.
k = i+3
while k+2 < len(input_tokens):
if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.':
k += 2
else:
break
output_tokens[-1] += ''.join(input_tokens[i:k])
i += 2
elif tok == "-":
if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-":
output_tokens.append("--")
i += 2
elif i == len(input_tokens) - 1 or i == 0:
output_tokens.append("-")
i += 1
elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation:
output_tokens[-1] += "-"
i += 1
flag_prev_dash = True
else:
output_tokens.append("-")
i += 1
elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation:
output_tokens[-1] += tok
i += 1
else:
output_tokens.append(tok)
i += 1
prev_dash = flag_prev_dash
return " ".join(output_tokens)
def process_eval(args, gen_list, tgt_list):
evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2,
limit_length=False, apply_avg=True, weight_factor=1.2)
max_score = []
avg_score = []
lowest_score = []
best_gen_list = []
# track the max score, avg score and lowest score over each group of sampled candidates
for index, sentence in enumerate(tqdm(gen_list)):
if index % args.num_samples == 0:
max_score_dict = {'rouge_1':0.0,'rouge_2':0.0,'rouge_l':0.0}
avg_score_dict = {'rouge_1': 0.0, 'rouge_2': 0.0, 'rouge_l': 0.0}
low_score_dict = {'rouge_1': 1.0, 'rouge_2': 1.0, 'rouge_l': 1.0}
target = tgt_list[index // args.num_samples]
scores = evaluator.get_scores([sentence], [[target]])
rouge_1 = scores['rouge-1']['f']
rouge_2 = scores['rouge-2']['f']
rouge_l = scores['rouge-l']['f']
# max
if rouge_2 >= max_score_dict['rouge_2']:
if rouge_2 != 0:
max_score_dict['rouge_1'] = rouge_1
max_score_dict['rouge_2'] = rouge_2
max_score_dict['rouge_l'] = rouge_l
best_sentence = sentence
else:
if rouge_1 >= max_score_dict['rouge_1']:
max_score_dict['rouge_1'] = rouge_1
max_score_dict['rouge_2'] = rouge_2
max_score_dict['rouge_l'] = rouge_l
best_sentence = sentence
# avg
avg_score_dict['rouge_1'] += rouge_1
avg_score_dict['rouge_2'] += rouge_2
avg_score_dict['rouge_l'] += rouge_l
# min
if rouge_2 < low_score_dict['rouge_2']:
low_score_dict['rouge_1'] = rouge_1
low_score_dict['rouge_2'] = rouge_2
low_score_dict['rouge_l'] = rouge_l
if (index + 1) % args.num_samples == 0:
max_score.append(max_score_dict)
best_gen_list.append(best_sentence)
avg_score_dict['rouge_1'] = avg_score_dict['rouge_1'] / args.num_samples
avg_score_dict['rouge_2'] = avg_score_dict['rouge_2'] / args.num_samples
avg_score_dict['rouge_l'] = avg_score_dict['rouge_l'] / args.num_samples
avg_score.append(avg_score_dict)
lowest_score.append(low_score_dict)
rouge_1 = 0
rouge_2 = 0
rouge_l = 0
for score_dict in max_score:
rouge_1 += score_dict['rouge_1']
rouge_2 += score_dict['rouge_2']
rouge_l += score_dict['rouge_l']
max_rouge_1 = rouge_1 / len(max_score)
max_rouge_2 = rouge_2 / len(max_score)
max_rouge_l = rouge_l / len(max_score)
rouge_1 = 0
rouge_2 = 0
rouge_l = 0
for score_dict in avg_score:
rouge_1 += score_dict['rouge_1']
rouge_2 += score_dict['rouge_2']
rouge_l += score_dict['rouge_l']
avg_rouge_1 = rouge_1 / len(max_score)
avg_rouge_2 = rouge_2 / len(max_score)
avg_rouge_l = rouge_l / len(max_score)
rouge_1 = 0
rouge_2 = 0
rouge_l = 0
for score_dict in lowest_score:
rouge_1 += score_dict['rouge_1']
rouge_2 += score_dict['rouge_2']
rouge_l += score_dict['rouge_l']
min_rouge_1 = rouge_1 / len(max_score)
min_rouge_2 = rouge_2 / len(max_score)
min_rouge_l = rouge_l / len(max_score)
scores = {'max_rouge_1':max_rouge_1, 'max_rouge_2':max_rouge_2, 'max_rouge_l':max_rouge_l,
'min_rouge_1':min_rouge_1, 'min_rouge_2':min_rouge_2, 'min_rouge_l':min_rouge_l,
'avg_rouge_1':avg_rouge_1, 'avg_rouge_2':avg_rouge_2, 'avg_rouge_l':avg_rouge_l }
return scores, best_gen_list
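# Candidate layout assumed by process_eval (a sketch, not part of the original file): gen_list
# holds num_samples consecutive generations per target, e.g. with num_samples = 2
#   gen_list = [cand0_a, cand0_b, cand1_a, cand1_b]  is scored against  tgt_list = [tgt0, tgt1]
# and the best / average / worst ROUGE within each block of num_samples is reported.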
def main():
args = get_arguments()
tgt = []
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
# text = fix_tokenization(line)
tgt.append(text)
final_scores = {'max_rouge_1': 0.0, 'max_rouge_2': 0.0, 'max_rouge_l': 0.0,
'min_rouge_1': 0.0, 'min_rouge_2': 0.0, 'min_rouge_l': 0.0,
'avg_rouge_1': 0.0, 'avg_rouge_2': 0.0, 'avg_rouge_l': 0.0}
tgt_offset = 0
best_gen_list_total = []
for i in range(args.n_gpu):
gen_list_total = []
for epoch in range(args.num_samples):
gen_list = []
gen_text_path = os.path.join(args.generate_path, "rank" + str(i)+"_gen_seed_" + str(args.seed) +
"_num" + str(args.num_samples) + "_epoch" + str(epoch+1) + ".txt")
with open(gen_text_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip()
text = line
gen_list.append(text)
gen_list_total.append(gen_list)
gen_peace = []
for index in range(len(gen_list_total[0])):
for k in range(args.num_samples):
gen_peace.append(gen_list_total[k][index])
gen_len = len(gen_peace) // args.num_samples
scores, best_gen_list = process_eval(args, gen_peace, tgt[tgt_offset:tgt_offset+gen_len])
best_gen_list_total.extend(best_gen_list)
for key, values in scores.items():
final_scores[key] += values
print("scores on gpu ", i)
print(scores)
tgt_offset += gen_len
for key, values in final_scores.items():
final_scores[key] = values / args.n_gpu
print("final score :")
print(final_scores)
assert len(best_gen_list_total) == len(tgt)
print("store best gen list ...")
out_path = os.path.join(args.generate_path, "best_gen_list.txt")
with open(out_path, 'w') as f:
for sentence in best_gen_list_total:
f.write(sentence + '\n')
if __name__ == "__main__":
main()


@ -0,0 +1,277 @@
import os
import argparse
from tqdm import tqdm
import string
import rouge
def get_arguments():
parser = argparse.ArgumentParser()
# out path
parser.add_argument('--generate_path', type=str, default='', help='output path')
parser.add_argument('--num_samples', type=int, default=50, help='sample query')
# data args
parser.add_argument('--data_path', type=str, default='', help='data path')
parser.add_argument('--data_name', type=str, default='', help='data name')
parser.add_argument('--batch_size', type=int, default=64, help='')
# seed
parser.add_argument('--seed', type=int, default=101, help='')
parser.add_argument('--n_gpu', type=int, default=4, help='')
args = parser.parse_args()
return args
_tok_dict = {"(": "-lrb-", ")": "-rrb-",
"[": "-lsb-", "]": "-rsb-",
"{": "-lcb-", "}": "-rcb-",
"[UNK]": "UNK", '&': '&amp;', '<': '&lt;', '>': '&gt;'}
def _is_digit(w):
for ch in w:
if not(ch.isdigit() or ch == ','):
return False
return True
def fix_tokenization(text):
input_tokens = text.split()
output_tokens = []
has_left_quote = False
has_left_single_quote = False
i = 0
prev_dash = False
while i < len(input_tokens):
tok = input_tokens[i]
flag_prev_dash = False
if tok in _tok_dict.keys():
output_tokens.append(_tok_dict[tok])
i += 1
elif tok == "\"":
if has_left_quote:
output_tokens.append("''")
else:
output_tokens.append("``")
has_left_quote = not has_left_quote
i += 1
elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t":
output_tokens[-1] = output_tokens[-1][:-1]
output_tokens.append("n't")
i += 2
elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"):
output_tokens.append("'"+input_tokens[i + 1])
i += 2
elif tok == "'":
if has_left_single_quote:
output_tokens.append("'")
else:
output_tokens.append("`")
has_left_single_quote = not has_left_single_quote
i += 1
elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".":
output_tokens.append("...")
i += 3
elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]):
# $ 3 , 000 -> $ 3,000
output_tokens[-1] += ','+input_tokens[i + 1]
i += 2
elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit():
# 3 . 03 -> 3.03
output_tokens[-1] += '.'+input_tokens[i + 1]
i += 2
elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.':
# U . N . -> U.N.
k = i+3
while k+2 < len(input_tokens):
if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.':
k += 2
else:
break
output_tokens[-1] += ''.join(input_tokens[i:k])
i += 2
elif tok == "-":
if i < len(input_tokens) - 1 and input_tokens[i + 1] == "-":
output_tokens.append("--")
i += 2
elif i == len(input_tokens) - 1 or i == 0:
output_tokens.append("-")
i += 1
elif output_tokens[-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation:
output_tokens[-1] += "-"
i += 1
flag_prev_dash = True
else:
output_tokens.append("-")
i += 1
elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation:
output_tokens[-1] += tok
i += 1
else:
output_tokens.append(tok)
i += 1
prev_dash = flag_prev_dash
return " ".join(output_tokens)
def process_eval(args, gen_list, tgt_list, ori_gen_peace):
evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2,
limit_length=False, apply_avg=True, weight_factor=1.2)
max_score = []
avg_score = []
lowest_score = []
best_gen_list = []
# track the max score, avg score and lowest score over each group of sampled candidates
for index, sentence in enumerate(tqdm(gen_list)):
if index % args.num_samples == 0:
max_score_dict = {'rouge_1':0.0,'rouge_2':0.0,'rouge_l':0.0}
avg_score_dict = {'rouge_1': 0.0, 'rouge_2': 0.0, 'rouge_l': 0.0}
low_score_dict = {'rouge_1': 1.0, 'rouge_2': 1.0, 'rouge_l': 1.0}
target = tgt_list[index // args.num_samples]
scores = evaluator.get_scores([sentence], [[target]])
rouge_1 = scores['rouge-1']['f']
rouge_2 = scores['rouge-2']['f']
rouge_l = scores['rouge-l']['f']
# max
if rouge_2 >= max_score_dict['rouge_2']:
if rouge_2 != 0:
max_score_dict['rouge_1'] = rouge_1
max_score_dict['rouge_2'] = rouge_2
max_score_dict['rouge_l'] = rouge_l
best_sentence = ori_gen_peace[index]
else:
if rouge_1 >= max_score_dict['rouge_1']:
max_score_dict['rouge_1'] = rouge_1
max_score_dict['rouge_2'] = rouge_2
max_score_dict['rouge_l'] = rouge_l
best_sentence = ori_gen_peace[index]
# avg
avg_score_dict['rouge_1'] += rouge_1
avg_score_dict['rouge_2'] += rouge_2
avg_score_dict['rouge_l'] += rouge_l
# min
if rouge_2 < low_score_dict['rouge_2']:
low_score_dict['rouge_1'] = rouge_1
low_score_dict['rouge_2'] = rouge_2
low_score_dict['rouge_l'] = rouge_l
if (index + 1) % args.num_samples == 0:
max_score.append(max_score_dict)
best_gen_list.append(best_sentence)
avg_score_dict['rouge_1'] = avg_score_dict['rouge_1'] / args.num_samples
avg_score_dict['rouge_2'] = avg_score_dict['rouge_2'] / args.num_samples
avg_score_dict['rouge_l'] = avg_score_dict['rouge_l'] / args.num_samples
avg_score.append(avg_score_dict)
lowest_score.append(low_score_dict)
rouge_1 = 0
rouge_2 = 0
rouge_l = 0
for score_dict in max_score:
rouge_1 += score_dict['rouge_1']
rouge_2 += score_dict['rouge_2']
rouge_l += score_dict['rouge_l']
max_rouge_1 = rouge_1 / len(max_score)
max_rouge_2 = rouge_2 / len(max_score)
max_rouge_l = rouge_l / len(max_score)
rouge_1 = 0
rouge_2 = 0
rouge_l = 0
for score_dict in avg_score:
rouge_1 += score_dict['rouge_1']
rouge_2 += score_dict['rouge_2']
rouge_l += score_dict['rouge_l']
avg_rouge_1 = rouge_1 / len(max_score)
avg_rouge_2 = rouge_2 / len(max_score)
avg_rouge_l = rouge_l / len(max_score)
rouge_1 = 0
rouge_2 = 0
rouge_l = 0
for score_dict in lowest_score:
rouge_1 += score_dict['rouge_1']
rouge_2 += score_dict['rouge_2']
rouge_l += score_dict['rouge_l']
min_rouge_1 = rouge_1 / len(max_score)
min_rouge_2 = rouge_2 / len(max_score)
min_rouge_l = rouge_l / len(max_score)
scores = {'max_rouge_1':max_rouge_1, 'max_rouge_2':max_rouge_2, 'max_rouge_l':max_rouge_l,
'min_rouge_1':min_rouge_1, 'min_rouge_2':min_rouge_2, 'min_rouge_l':min_rouge_l,
'avg_rouge_1':avg_rouge_1, 'avg_rouge_2':avg_rouge_2, 'avg_rouge_l':avg_rouge_l }
return scores, best_gen_list
def main():
args = get_arguments()
tgt = []
test_tgt_path = os.path.join(args.data_path, args.data_name + "/org_data/test.tgt")
with open(test_tgt_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
line = line.strip().replace(" <S_SEP> ", '\n')
text = line
tgt.append(text)
final_scores = {'max_rouge_1': 0.0, 'max_rouge_2': 0.0, 'max_rouge_l': 0.0,
'min_rouge_1': 0.0, 'min_rouge_2': 0.0, 'min_rouge_l': 0.0,
'avg_rouge_1': 0.0, 'avg_rouge_2': 0.0, 'avg_rouge_l': 0.0}
tgt_offset = 0
best_gen_list_total = []
for i in range(args.n_gpu):
gen_list_total = []
ori_gen_list_total = []
for epoch in range(args.num_samples):
gen_list = []
ori_gen_list = []
gen_text_path = os.path.join(args.generate_path, "rank" + str(i)+"_gen_seed_" + str(args.seed) +
"_num" + str(args.num_samples) + "_epoch" + str(epoch+1) + ".txt")
with open(gen_text_path, "r", encoding="utf-8") as ifile:
for line in tqdm(ifile):
ori_gen_list.append(line.strip())
line = line.strip().split(' < s _ sep > ')
line = [fix_tokenization(sen) for sen in line]
text = "\n".join(line)
# text = line
gen_list.append(text)
gen_list_total.append(gen_list)
ori_gen_list_total.append(ori_gen_list)
gen_peace = []
for index in range(len(gen_list_total[0])):
for k in range(args.num_samples):
gen_peace.append(gen_list_total[k][index])
ori_gen_peace = []
for index in range(len(ori_gen_list_total[0])):
for k in range(args.num_samples):
ori_gen_peace.append(ori_gen_list_total[k][index])
gen_len = len(gen_peace) // args.num_samples
scores, best_gen_list = process_eval(args, gen_peace, tgt[tgt_offset:tgt_offset+gen_len], ori_gen_peace)
best_gen_list_total.extend(best_gen_list)
for key, values in scores.items():
final_scores[key] += values
print("scores on gpu ", i)
print(scores)
tgt_offset += gen_len
for key, values in final_scores.items():
final_scores[key] = values / args.n_gpu
print("final score :")
print(final_scores)
assert len(best_gen_list_total) == len(tgt)
print("store best gen list ...")
out_path = os.path.join(args.generate_path, "best_gen_list.txt")
with open(out_path, 'w') as f:
for sentence in best_gen_list_total:
f.write(sentence + '\n')
if __name__ == "__main__":
main()


@ -0,0 +1,372 @@
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is self-attn if context is none
# layer norms
self.use_ada_layer_norm = num_embeds_ada_norm is not None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
# def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
# if not is_xformers_available():
# print("Here is how to install it")
# raise ModuleNotFoundError(
# "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
# " xformers",
# name="xformers",
# )
# elif not torch.cuda.is_available():
# raise ValueError(
# "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
# " available for GPU "
# )
# else:
# try:
# # Make sure we can run the memory efficient attention
# _ = xformers.ops.memory_efficient_attention(
# torch.randn((1, 2, 40), device="cuda"),
# torch.randn((1, 2, 40), device="cuda"),
# torch.randn((1, 2, 40), device="cuda"),
# )
# except Exception as e:
# raise e
# self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
# self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
'''
input value
hidden_state: encode from query (batch_size, query_seq_len, hidden_size)
context: encode from passage (batch_size, passage_seq_len, hidden_size)
output value
hidden_states : (batch_size, query_seq_len, hidden_size)
'''
def forward(self, hidden_states, context=None):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
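# Shape sketch for the block above (illustrative; the sizes are assumptions, not fixed by the
# model). hidden_states play the role of noisy target-side embeddings, context the encoded source.
def _example_cross_attention_block():
    block = BasicTransformerBlock(dim=768, num_attention_heads=12, attention_head_dim=64)
    hidden_states = torch.randn(2, 32, 768)   # (batch, tgt_seq_len, hidden)
    context = torch.randn(2, 144, 768)        # (batch, src_seq_len, hidden)
    out = block(hidden_states, context=context)
    assert out.shape == (2, 32, 768)
    return out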
class CrossAttention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value):
# TODO: use baddbmm for better performance
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
else:
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
if query.device.type == "mps":
hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
else:
hidden_states = torch.matmul(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
* self.scale
)
else:
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
if query.device.type == "mps":
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
else:
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
# def _memory_efficient_attention_xformers(self, query, key, value):
# query = query.contiguous()
# key = key.contiguous()
# value = value.contiguous()
# hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
# hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
# return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "geglu":
geglu = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
geglu = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(geglu)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
# feedforward
class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
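# i.e. GEGLU(x) = h * GELU(g), where h and g are the two halves of one linear projection of x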
class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def forward(self, x):
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
class AdaLayerNorm(nn.Module):
"""
Norm layer modified to incorporate timestep embeddings.
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
def forward(self, x, timestep):
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x

372
GENIE/model/Diffusion_LM.py Normal file

@ -0,0 +1,372 @@
import torch
from torch import nn
import numpy as np
import math
from transformers import (
BertModel,
BertConfig,
)
from model.CrossAttentionTransformers import BasicTransformerBlock
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
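# Quick shape sketch (an illustration mirroring the docstring above):
def _example_timestep_embedding():
    t = torch.tensor([0, 500, 1999])
    emb = timestep_embedding(t, dim=768)
    assert emb.shape == (3, 768)   # cosine half followed by sine half
    return emb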
class Diffusion_LM(nn.Module):
def __init__(
self,
in_channels,
model_channels,
out_channels,
dropout=0,
config=None,
config_name='bert-base-uncased',
vocab_size=None,
init_pretrained=True,
logits_mode=1,
token_emb_type='pretrain',
# num_heads=1,
# channel_mult=(1, 2, 4, 8),
# use_scale_shift_norm=False,
# training_mode='emb',
# experiment_mode='lm',
# num_heads_upsample=-1,
# use_checkpoint=False,
# num_classes=None,
# conv_resample=True,
# attention_resolutions,
# num_res_blocks,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.dropout = dropout
self.logits_mode = logits_mode
self.init_pretrained = init_pretrained
self.token_emb_type = token_emb_type
config = BertConfig.from_pretrained(config_name)
config.hidden_dropout_prob = self.dropout
print(config)
# trainable embedding layer
self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
# position embedding
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
if self.token_emb_type == 'pretrain':
temp_bert = BertModel.from_pretrained(config_name, config=config)
self.word_embedding.weight = temp_bert.embeddings.word_embeddings.weight
self.position_embeddings.weight = temp_bert.embeddings.position_embeddings.weight
elif self.token_emb_type == 'random':
print("load embedding weight random")
else:
raise NotImplementedError(f"unknown token_emb_type: {self.token_emb_type}")
if self.logits_mode == 2:
# self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=False)
self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=True)
else:
self.lm_head = nn.Linear(self.in_channels, vocab_size)
# share weight between lm_head and word_embedding
with torch.no_grad():
self.lm_head.weight = self.word_embedding.weight
# self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
# self.lm_head = nn.Linear(self.in_channels, vocab_size)
# with th.no_grad():
# self.lm_head.weight = self.word_embedding.weight
# time embedding
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, config.hidden_size),
)
# # label embedding
# if self.num_classes is not None:
# self.label_emb = nn.Embedding(num_classes, time_embed_dim)
# input transform
self.input_up_proj = nn.Sequential(
nn.Linear(in_channels, config.hidden_size),
nn.Tanh(),
nn.Linear(config.hidden_size, config.hidden_size)
)
# Dropout
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
'''
Diffusion Transformer (6 layer)
'''
if self.init_pretrained:
temp_bert = BertModel.from_pretrained(config_name, config=config)
del temp_bert.embeddings
del temp_bert.pooler
self.input_transformers = temp_bert.encoder
print('initializing from pretrained bert.')
else:
temp_bert = BertModel(config)
self.input_transformers = temp_bert.encoder
print('initializing from random bert.')
# output transform
self.output_down_proj = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.Tanh(),
nn.Linear(config.hidden_size, out_channels)
)
def get_embeds(self, input_ids):
return self.word_embedding(input_ids)
def get_logits(self, hidden_repr):
if self.logits_mode == 1:
return self.lm_head(hidden_repr)
elif self.logits_mode == 2:
text_emb = hidden_repr
emb_norm = (self.lm_head.weight ** 2).sum(-1).view(-1, 1) # vocab
text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(self.lm_head.weight,
text_emb_t) # (vocab, d) x (d, bsz*seqlen)
scores = torch.sqrt(torch.clamp(dist, 0.0, np.inf)).view(emb_norm.size(0), hidden_repr.size(0),
hidden_repr.size(1)) # vocab, bsz*seqlen
scores = -scores.permute(1, 2, 0).contiguous()
#
# scores1 = th.cdist(self.lm_head.weight.unsqueeze(0), hidden_repr, p=2)
# scores1 = -scores1.permute(0, 2, 1).contiguous()
#
# print(scores1.shape, scores.shape)
# print(scores1[0,0], scores[0,0])
# print(torch.isclose(scores1, scores))
return scores
else:
raise NotImplementedError
def forward(self, x, timesteps, attention_mask=None, y=None, src_ids=None, src_mask=None):
# prepare embedding
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb_x = self.input_up_proj(x)
seq_length = x.size(1)
position_ids = self.position_ids[:, : seq_length]
# print(emb_x.shape, emb.shape, self.position_embeddings)
emb_inputs = self.position_embeddings(position_ids) + emb_x + emb.unsqueeze(1).expand(-1, seq_length, -1)
emb_inputs = self.dropout(self.LayerNorm(emb_inputs))
# encode embedding
# print(emb_inputs.shape, attention_mask.shape)
input_trans_hidden_states = self.input_transformers(emb_inputs, attention_mask=attention_mask).last_hidden_state
h = self.output_down_proj(input_trans_hidden_states)
h = h.type(x.dtype)
return h
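# Illustrative sketch (assumptions: the 'bert-base-uncased' config is available
# locally or downloadable, and the shapes mirror the defaults above): a single
# denoising call maps noisy embeddings x_t of shape [batch, seq_len, in_channels]
# to an output in embedding space, which get_logits projects onto the vocabulary.
def _diffusion_lm_forward_example():
    import torch
    model = Diffusion_LM(in_channels=768, model_channels=768, out_channels=768,
                         vocab_size=30522, init_pretrained=False,
                         token_emb_type='random')
    x_t = torch.randn(2, 16, 768)           # noisy word embeddings
    t = torch.randint(0, 2000, (2,))        # one diffusion step per sequence
    out = model(x_t, t)                     # [2, 16, 768]
    logits = model.get_logits(out)          # [2, 16, 30522]
    return logits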
class CrossAttention_Diffusion_LM(nn.Module):
def __init__(
self,
in_channels,
model_channels,
out_channels,
dropout=0,
config=None,
config_name='bert-base-uncased',
vocab_size=None,
init_pretrained=True,
logits_mode=1,
token_emb_type='pretrain',
fix_encoder=False
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.dropout = dropout
self.logits_mode = logits_mode
self.init_pretrained = init_pretrained
self.token_emb_type = token_emb_type
self.fix_encoder = fix_encoder
cfg = BertConfig.from_pretrained(config_name)
cfg.num_hidden_layers = 6
self.passage_encoder = BertModel.from_pretrained(config_name, config=cfg)
# self.passage_encoder = BertModel.from_pretrained(
# "/colab_space/Lin0/PROD/KDexp/pretrain_model/bert-base-uncased", config=cfg)
config = BertConfig.from_pretrained(config_name)
config.hidden_dropout_prob = self.dropout
print(config)
# trainable embedding layer
self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
# position embedding
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
if self.logits_mode == 2:
# self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=False)
self.lm_head = nn.Linear(self.in_channels, vocab_size, bias=True)
else:
self.lm_head = nn.Linear(self.in_channels, vocab_size)
# share weight between lm_head and word_embedding
with torch.no_grad():
self.lm_head.weight = self.word_embedding.weight
# self.word_embedding = nn.Embedding(vocab_size, self.in_channels)
# self.lm_head = nn.Linear(self.in_channels, vocab_size)
# with th.no_grad():
# self.lm_head.weight = self.word_embedding.weight
# time embedding layer
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(model_channels, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, config.hidden_size),
)
# # label embedding
# if self.num_classes is not None:
# self.label_emb = nn.Embedding(num_classes, time_embed_dim)
# input transform
self.input_up_proj = nn.Sequential(
nn.Linear(in_channels, config.hidden_size),
nn.Tanh(),
nn.Linear(config.hidden_size, config.hidden_size)
)
# Dropout
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
config.num_hidden_layers = 6
# define cross attention transformer block(6 layer)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=config.hidden_size,
num_attention_heads=config.num_attention_heads,
attention_head_dim=config.hidden_size // config.num_attention_heads,
dropout=config.hidden_dropout_prob,
cross_attention_dim=config.hidden_size,
activation_fn="geglu",
)
for d in range(config.num_hidden_layers)
]
)
# output transform
self.output_down_proj = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.Tanh(),
nn.Linear(config.hidden_size, out_channels)
)
def get_embeds(self, input_ids):
return self.word_embedding(input_ids)
def get_logits(self, hidden_repr):
if self.logits_mode == 1:
return self.lm_head(hidden_repr)
elif self.logits_mode == 2:
text_emb = hidden_repr
emb_norm = (self.lm_head.weight ** 2).sum(-1).view(-1, 1) # vocab
text_emb_t = torch.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1) # d, bsz*seqlen
arr_norm = (text_emb ** 2).sum(-1).view(-1, 1) # bsz*seqlen, 1
dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * torch.mm(self.lm_head.weight,
text_emb_t) # (vocab, d) x (d, bsz*seqlen)
scores = torch.sqrt(torch.clamp(dist, 0.0, np.inf)).view(emb_norm.size(0), hidden_repr.size(0),
hidden_repr.size(1)) # vocab, bsz*seqlen
scores = -scores.permute(1, 2, 0).contiguous()
return scores
else:
raise NotImplementedError
def forward(self, x, timesteps, src_input_ids, src_attention_mask, attention_mask=None,
answer_id=None, answer_mask=None, y=None, src_ids=None, src_mask=None):
# prepare embedding
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
emb_x = self.input_up_proj(x)
seq_length = x.size(1)
position_ids = self.position_ids[:, : seq_length]
# print(emb_x.shape, emb.shape, self.position_embeddings)
emb_inputs = self.position_embeddings(position_ids) + emb_x + emb.unsqueeze(1).expand(-1, seq_length, -1)
hidden_states = self.dropout(self.LayerNorm(emb_inputs))
# encode embedding
# print(emb_inputs.shape, attention_mask.shape)
if self.fix_encoder:
with torch.no_grad():
out = self.passage_encoder(input_ids=src_input_ids,
attention_mask=src_attention_mask)
passage_hidden = out.last_hidden_state
else:
out = self.passage_encoder(input_ids=src_input_ids,
attention_mask=src_attention_mask)
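# adding 0 * pooler_output leaves passage_hidden unchanged but keeps the pooler
# parameters in the autograd graph, which avoids unused-parameter errors under DDP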
passage_hidden = out.last_hidden_state + 0 * out.pooler_output.unsqueeze(1)
if answer_id is not None:
answer_hidden_states = hidden_states.clone()
answer_out = self.passage_encoder(input_ids=answer_id,
attention_mask=answer_mask)
answer_hidden = answer_out.last_hidden_state + 0 * answer_out.pooler_output.unsqueeze(1)
for block in self.transformer_blocks:
answer_hidden_states = block(answer_hidden_states, answer_hidden)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, passage_hidden)
if answer_id is not None:
# print("model_qg_forward...")
hidden_states = hidden_states + answer_hidden_states
h = self.output_down_proj(hidden_states)
h = h.type(x.dtype)
return h

0
GENIE/model/__init__.py Normal file

@ -0,0 +1,82 @@
"""
Helpers for distributed training.
"""
import io
import os
import socket
import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 1 #8
SETUP_RETRY_COUNT = 3
def setup_dist():
"""
Setup a distributed process group.
"""
if dist.is_initialized():
return
comm = MPI.COMM_WORLD
backend = "gloo" if not th.cuda.is_available() else "nccl"
if backend == "gloo":
hostname = "localhost"
else:
hostname = socket.gethostbyname(socket.getfqdn())
os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
os.environ["RANK"] = str(comm.rank)
os.environ["WORLD_SIZE"] = str(comm.size)
port = comm.bcast(_find_free_port(), root=0)
os.environ["MASTER_PORT"] = str(port)
dist.init_process_group(backend=backend, init_method="env://")
def dev():
"""
Get the device to use for torch.distributed.
"""
if th.cuda.is_available():
return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
return th.device("cpu")
def load_state_dict(path, **kwargs):
"""
Load a PyTorch file without redundant fetches across MPI ranks.
"""
if MPI.COMM_WORLD.Get_rank() == 0:
with bf.BlobFile(path, "rb") as f:
data = f.read()
else:
data = None
data = MPI.COMM_WORLD.bcast(data)
return th.load(io.BytesIO(data), **kwargs)
def sync_params(params):
"""
Synchronize a sequence of Tensors across ranks from rank 0.
"""
for p in params:
with th.no_grad():
dist.broadcast(p, 0)
def _find_free_port():
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
finally:
s.close()
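# Illustrative usage sketch (assumptions: an MPI launcher such as
# `mpiexec -n 2 python train.py` and a checkpoint file named "model.pt"):
def _dist_util_example(model):
    setup_dist()                                   # initialize the process group via MPI
    device = dev()                                 # cuda:{rank % GPUS_PER_NODE} or cpu
    model.to(device)
    state = load_state_dict("model.pt", map_location="cpu")  # rank 0 reads, then broadcasts
    model.load_state_dict(state)
    sync_params(model.parameters())                # copy rank-0 weights to every rank
    return device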


@ -0,0 +1,375 @@
import torch
import copy
import os
from torch import nn
import collections
from util import logger
from train_util import dist_util
from transformers import AdamW
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.serialization import default_restore_location
from tqdm import tqdm
import numpy as np
from transformers import (
get_linear_schedule_with_warmup,
)
from data_util.pretrain_data_util import load_loop_pretrain_data
# from text_data_util import Text_Hidden_dataset, Question_dataset, PandQ_dataset
# from s2s_data_util import S2S_dataset
from data_util.pretrain_data_util import Pre_dataset, Pre_dataset_type2
from diffusion_util.resample import LossAwareSampler, UniformSampler
INITIAL_LOG_LOSS_SCALE = 20.0
CheckpointState = collections.namedtuple("CheckpointState",
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
'''
TrainLoop
'''
class PretrainLoop:
def __init__(
self,
train_type,
model,
diffusion,
data,
batch_size,
lr,
ema_rate,
log_interval,
save_interval,
resume_checkpoint,
warmup_steps=0,
use_fp16=False,
fp16_scale_growth=1e-3,
schedule_sampler=None,
weight_decay=0.0,
lr_anneal_steps=0,
checkpoint_path='',
gradient_clipping=-1.,
eval_data=None,
eval_interval=-1,
gradient_accumulation_steps=1,
device=None,
args=None,
tokenizer=None
):
self.train_type = train_type
self.model = model
self.diffusion = diffusion
self.data = data
self.eval_data = eval_data
self.batch_size = batch_size
self.lr = lr
self.ema_rate = (
[ema_rate]
if isinstance(ema_rate, float)
else [float(x) for x in ema_rate.split(",")]
)
self.log_interval = log_interval
self.eval_interval = eval_interval
self.save_interval = save_interval
self.resume_checkpoint = resume_checkpoint
self.warmup_steps = warmup_steps
self.use_fp16 = use_fp16
self.fp16_scale_growth = fp16_scale_growth
self.schedule_sampler = schedule_sampler
self.weight_decay = weight_decay
self.lr_anneal_steps = lr_anneal_steps
self.gradient_clipping = gradient_clipping
self.gradient_accumulation_steps = gradient_accumulation_steps
self.device = device
self.args = args
self.tokenizer = tokenizer
# self.global_batch = self.batch_size * dist.get_world_size()
self.master_params = list(self.model.parameters())
self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
# self.sync_cuda = th.cuda.is_available()
self.checkpoint_path = checkpoint_path
self.optimizer = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
self.total_step = 5000000
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.total_step
)
self.global_step = 0
if self.checkpoint_path is not None:
model_checkpoint_files = []
ema_checkpoint_files = []
if os.path.exists(self.checkpoint_path):
for item in os.scandir(self.checkpoint_path):
if item.is_file():
if "model_checkpoint" in item.path:
model_checkpoint_files.append(item.path)
if "ema" in item.path:
ema_checkpoint_files.append(item.path)
if len(model_checkpoint_files) != 0 and len(ema_checkpoint_files) != 0:
model_checkpoint_files.sort(key=lambda f: int(f.split('model_checkpoint-')[1]), reverse=True)
logger.info("***** load " + model_checkpoint_files[0] + " *****")
ema_checkpoint_files.sort(key=lambda f: int(f.split('checkpoint-')[-1]), reverse=True)
logger.info("***** load " + ema_checkpoint_files[0] + " *****")
model_saved_state = load_states_from_checkpoint(model_checkpoint_files[0])
self.global_step = self._load_saved_state(model_saved_state)
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
else:
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
logger.info("***** there are no checkpoint in" + self.checkpoint_path + " *****")
else:
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
# model to DDP
if dist.get_world_size() > 1:
self.model = DDP(
self.model, device_ids=[dist.get_rank()], output_device=dist.get_rank(), find_unused_parameters=False,
)
else:
print("single GPU is not achieve now")
exit(0)
def run_loop(self):
logger.info("***** Running training *****")
logger.info(" Max steps = %d", self.total_step)
logger.info(" Instantaneous batch size per GPU = %d", self.batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
self.batch_size
* self.gradient_accumulation_steps
* (dist.get_world_size()),
)
logger.info(" Gradient Accumulation steps = %d", self.gradient_accumulation_steps)
self.model.zero_grad()
self.model.train()
while self.global_step < self.total_step:
for data_name in self.data:
print("pretraining diffusion using data :", data_name)
train_data = load_loop_pretrain_data(
self.args,
padding_mode='conti_tgt',
tokenizer=self.tokenizer,
data_name=data_name,
)
# ddp data sample
train_sample = DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sample, batch_size=self.batch_size, drop_last=False,
num_workers=20, collate_fn=Pre_dataset_type2.get_collate_fn())
# while self.global_step < self.lr_anneal_steps:
# training for one epoch
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=dist.get_rank() not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
self.model.train()
# forward loss
self.forward_backward(batch)
if self.use_fp16:
pass
else:
if (step + 1) % self.gradient_accumulation_steps == 0:
# gradient clip
if self.gradient_clipping > 0:
self.grad_clip()
self._log_grad_norm()
self.optimizer.step()
# lr scheduler
# self._anneal_lr()
self.scheduler.step()
self.model.zero_grad()
for rate, params in zip(self.ema_rate, self.ema_params):
self.update_ema(params, self.master_params, rate=rate)
self.global_step += 1
self.log_step()
if (self.global_step + 1) % self.log_interval == 0:
logger.dumpkvs()
# dist.barrier()
if (self.global_step + 1) % self.save_interval == 0:
self.save()
# self.save()
if self.global_step > self.total_step:
break
def save(self):
def save_checkpoint(rate, ema_params):
model_to_save = get_model_obj(self.model)
if not rate:
model_state_dict = model_to_save.state_dict()
else:
model_state_dict = model_to_save.state_dict()
for i, (name, _value) in enumerate(model_to_save.named_parameters()):
assert name in model_state_dict
model_state_dict[name] = ema_params[i]
opt_state_dict = self.optimizer.state_dict()
sch_state_dict = self.scheduler.state_dict()
offset = self.global_step
state = CheckpointState(model_state_dict,
opt_state_dict,
sch_state_dict,
offset,
)
if not rate:
ckpt_path = os.path.join(self.checkpoint_path, 'model_checkpoint-' + str(offset))
else:
ckpt_path = os.path.join(self.checkpoint_path, 'ema_' + str(rate) + '_checkpoint-' + str(offset))
torch.save(state._asdict(), ckpt_path)
logger.info('Saved checkpoint at %s', ckpt_path)
if dist.get_rank() == 0:
save_checkpoint(0, None)
for rate, params in zip(self.ema_rate, self.ema_params):
save_checkpoint(rate, params)
# dist.barrier()
def forward_backward(self, batch):
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
# print("src_input_ids shape:", batch['src_input_ids'].shape)
# print("tgt_input_ids shape:", batch['tgt_input_ids'].shape)
losses = self.diffusion.training_losses(self.model, batch, t)
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
loss = (losses["loss"] * weights).mean()
if self.gradient_accumulation_steps > 1:
loss = loss / self.gradient_accumulation_steps
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
(loss * loss_scale).backward()
else:
loss.backward()
def forward_only(self, batch):
with torch.no_grad():
self.model.zero_grad()
if self.train_type == 'LM_Diffusion':
t, weights = self.schedule_sampler.sample(batch[1].shape[0], dist_util.dev())
inputs_text = {"query_ids": batch[1].long().to(self.device),
"attention_mask_q": batch[2].long().to(self.device)}
losses = self.diffusion.training_losses(self.model, inputs_text, t)
elif self.train_type == 'S2S_Diffusion':
'''
for s2s
'''
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
losses = self.diffusion.training_losses(self.model, batch, t)
else:
raise NotImplementedError
log_loss_dict(
self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()}
)
def _log_grad_norm(self):
sqsum = 0.0
for p in self.master_params:
# print(p)
sqsum += (p.grad ** 2).sum().item()
logger.logkv_mean("grad_norm", np.sqrt(sqsum))
def log_step(self):
logger.logkv("step", self.global_step)
if self.use_fp16:
logger.logkv("lg_loss_scale", self.lg_loss_scale)
def _anneal_lr(self):
if not self.lr_anneal_steps:
return
frac_done = self.global_step / self.lr_anneal_steps
lr = self.lr * (1 - frac_done)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def grad_clip(self):
# print('doing gradient clipping')
max_grad_norm = self.gradient_clipping  # e.g. 3.0
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), #amp.master_params(self.opt) if self.use_apex else
max_grad_norm,
)
def _load_saved_state(self, saved_state: CheckpointState):
self.global_step = saved_state.offset
logger.info('Loading checkpoint @ step=%s', self.global_step)
logger.info('Loading saved model state ...')
self.model.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection
self.optimizer.load_state_dict(saved_state.optimizer_dict)
self.scheduler.load_state_dict(saved_state.scheduler_dict)
self.master_params = list(self.model.parameters())
return self.global_step
def update_ema(self, target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
# print("target_params:", targ.device)
# print("source_params:", src.device)
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
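# worked example: with rate=0.99, targ=1.0 and src=0.0 one update gives
# targ <- 0.99 * 1.0 + (1 - 0.99) * 0.0 = 0.99, so the EMA copy trails the
# training weights with a time constant of roughly 1 / (1 - rate) steps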
def log_loss_dict(diffusion, ts, losses):
for key, values in losses.items():
logger.logkv_mean(key, values.mean().item())
# Log the quantiles (four quartiles, in particular).
for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
quartile = int(4 * sub_t / diffusion.num_timesteps)
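# e.g. with diffusion.num_timesteps = 2000, t = 500 falls in q1 and t = 1500 in q3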
logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
def get_model_obj(model: nn.Module):
return model.module if hasattr(model, 'module') else model
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
logger.info('Reading saved model from %s', model_file)
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
logger.info('model_state_dict keys %s', state_dict.keys())
return CheckpointState(**state_dict)


@ -0,0 +1,396 @@
import torch
import copy
import os
from torch import nn
import collections
from util import logger
from train_util import dist_util
from transformers import AdamW
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
import numpy as np
from torch.serialization import default_restore_location
from transformers import (
get_linear_schedule_with_warmup,
)
from transformers import (
BertModel,
BertConfig,
)
from diffusion_util.resample import LossAwareSampler, UniformSampler
from data_util.text_data_util import Text_Hidden_dataset, Question_dataset, PandQ_dataset
from data_util.s2s_data_util import S2S_dataset, QG_dataset_Diff
INITIAL_LOG_LOSS_SCALE = 20.0
CheckpointState = collections.namedtuple("CheckpointState",
['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset'])
'''
TrainLoop training class
'''
class TrainLoop:
def __init__(
self,
train_type,
model,
diffusion,
data,
batch_size,
lr,
ema_rate,
log_interval,
save_interval,
resume_checkpoint,
warmup_steps=0,
use_fp16=False,
fp16_scale_growth=1e-3,
schedule_sampler=None,
weight_decay=0.0,
lr_anneal_steps=0,
checkpoint_path='',
gradient_clipping=-1.,
eval_data=None,
eval_interval=-1,
gradient_accumulation_steps=1,
device=None,
data_name="xsum_data",
):
self.train_type = train_type
self.model = model
self.diffusion = diffusion
self.data = data
self.eval_data = eval_data
self.batch_size = batch_size
self.lr = lr
self.ema_rate = (
[ema_rate]
if isinstance(ema_rate, float)
else [float(x) for x in ema_rate.split(",")]
)
self.log_interval = log_interval
self.eval_interval = eval_interval
self.save_interval = save_interval
self.resume_checkpoint = resume_checkpoint
self.warmup_steps = warmup_steps
self.use_fp16 = use_fp16
self.fp16_scale_growth = fp16_scale_growth
self.schedule_sampler = schedule_sampler
self.weight_decay = weight_decay
self.lr_anneal_steps = lr_anneal_steps
self.gradient_clipping = gradient_clipping
self.gradient_accumulation_steps = gradient_accumulation_steps
self.device = device
self.data_name = data_name
self.master_params = list(self.model.parameters())
self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
self.checkpoint_path = checkpoint_path
self.optimizer = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.lr_anneal_steps
)
self.global_step = 0
# load last checkpoint
if self.checkpoint_path is not None:
model_checkpoint_files = []
ema_checkpoint_files = []
if os.path.exists(self.checkpoint_path):
for item in os.scandir(self.checkpoint_path):
if item.is_file():
if "model_checkpoint" in item.path:
model_checkpoint_files.append(item.path)
if "ema" in item.path:
ema_checkpoint_files.append(item.path)
if len(model_checkpoint_files) != 0 and len(ema_checkpoint_files) != 0:
model_checkpoint_files.sort(key=lambda f: int(f.split('model_checkpoint-')[1]), reverse=True)
logger.info("***** load " + model_checkpoint_files[0] + " *****")
ema_checkpoint_files.sort(key=lambda f: int(f.split('checkpoint-')[-1]), reverse=True)
logger.info("***** load " + ema_checkpoint_files[0] + " *****")
model_saved_state = load_states_from_checkpoint(model_checkpoint_files[0])
self.global_step = self._load_saved_state(model_saved_state)
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
else:
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
logger.info("***** there are no checkpoint in" + self.checkpoint_path + " *****")
else:
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
# model to DDP
if dist.get_world_size() > 1:
self.model = DDP(
self.model, device_ids=[dist.get_rank()], output_device=dist.get_rank(), find_unused_parameters=False,
)
else:
print("single GPU is not achieve now")
exit(0)
def run_loop(self):
logger.info("***** Running training *****")
logger.info(" Max steps = %d", self.lr_anneal_steps)
logger.info(" Instantaneous batch size per GPU = %d", self.batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
self.batch_size
* self.gradient_accumulation_steps
* (dist.get_world_size()),
)
logger.info(" Gradient Accumulation steps = %d", self.gradient_accumulation_steps)
self.model.zero_grad()
self.model.train()
# ddp data sample
if self.train_type == 'LM_Diffusion':
train_sample = DistributedSampler(self.data)
train_dataloader = DataLoader(self.data, sampler=train_sample, batch_size=self.batch_size, drop_last=False,
num_workers=20, collate_fn=Question_dataset.get_collate_fn())
elif self.train_type == 'S2S_Diffusion':
train_sample = DistributedSampler(self.data)
'''
for s2s
'''
train_dataloader = DataLoader(self.data, sampler=train_sample, batch_size=self.batch_size, drop_last=False,
num_workers=20, collate_fn=S2S_dataset.get_collate_fn())
else:
raise NotImplementedError
while self.global_step < self.lr_anneal_steps:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=dist.get_rank() not in [-1, 0])
for batch in epoch_iterator:
self.model.train()
# forward loss
self.forward_backward(batch)
if self.use_fp16:
pass
else:
# gradient clip
if self.gradient_clipping > 0:
self.grad_clip()
self._log_grad_norm()
self.optimizer.step()
# lr scheduler
self.scheduler.step()
self.model.zero_grad()
# ema
for rate, params in zip(self.ema_rate, self.ema_params):
self.update_ema(params, self.master_params, rate=rate)
self.global_step += 1
self.log_step()
if self.global_step % self.log_interval == 0:
logger.dumpkvs()
if self.eval_data is not None and self.global_step % self.eval_interval == 0:
if dist.get_rank() == 0:
print('eval on validation set...')
if self.train_type == 'LM_Diffusion':
dev_dataloader = DataLoader(self.eval_data, batch_size=self.batch_size,
drop_last=False,
num_workers=20, collate_fn=Question_dataset.get_collate_fn())
elif self.train_type == 'S2S_Diffusion':
'''
for s2s
'''
dev_dataloader = DataLoader(self.eval_data, batch_size=self.batch_size,
drop_last=False,
num_workers=20, collate_fn=S2S_dataset.get_collate_fn())
else:
raise NotImplementedError
for step, batch in enumerate(dev_dataloader):
self.forward_only(batch)
if step > 10:
break
logger.dumpkvs()
# save
if self.global_step % self.save_interval == 0:
self.save()
def save(self):
def save_checkpoint(rate, ema_params):
model_to_save = get_model_obj(self.model)
if not rate:
model_state_dict = model_to_save.state_dict()
else:
model_state_dict = model_to_save.state_dict()
for i, (name, _value) in enumerate(model_to_save.named_parameters()):
assert name in model_state_dict
model_state_dict[name] = ema_params[i]
opt_state_dict = self.optimizer.state_dict()
sch_state_dict = self.scheduler.state_dict()
offset = self.global_step
state = CheckpointState(model_state_dict,
opt_state_dict,
sch_state_dict,
offset,
)
if not rate:
ckpt_path = os.path.join(self.checkpoint_path, 'model_checkpoint-' + str(offset))
else:
ckpt_path = os.path.join(self.checkpoint_path, 'ema_' + str(rate) + '_checkpoint-' + str(offset))
torch.save(state._asdict(), ckpt_path)
logger.info('Saved checkpoint at %s', ckpt_path)
if dist.get_rank() == 0:
save_checkpoint(0, None)
for rate, params in zip(self.ema_rate, self.ema_params):
save_checkpoint(rate, params)
def forward_backward(self, batch):
if self.train_type == 'LM_Diffusion':
t, weights = self.schedule_sampler.sample(batch[1].shape[0], self.device)
inputs_text = {"query_ids": batch[1].long().to(self.device),
"attention_mask_q": batch[2].long().to(self.device)}
losses = self.diffusion.training_losses(self.model, inputs_text, t)
elif self.train_type == 'S2S_Diffusion':
'''
for s2s
'''
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
losses = self.diffusion.training_losses(self.model, batch, t)
else:
raise NotImplementedError
if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
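# the sampler weights compensate for non-uniform timestep sampling; with the
# UniformSampler they are all 1, so this reduces to a plain mean over the batch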
loss = (losses["loss"] * weights).mean()
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
(loss * loss_scale).backward()
else:
loss.backward()
def forward_only(self, batch):
with torch.no_grad():
self.model.zero_grad()
if self.train_type == 'LM_Diffusion':
t, weights = self.schedule_sampler.sample(batch[1].shape[0], dist_util.dev())
inputs_text = {"query_ids": batch[1].long().to(self.device),
"attention_mask_q": batch[2].long().to(self.device)}
losses = self.diffusion.training_losses(self.model, inputs_text, t)
elif self.train_type == 'S2S_Diffusion':
'''
for s2s
'''
t, weights = self.schedule_sampler.sample(batch['src_input_ids'].shape[0], self.device)
losses = self.diffusion.training_losses(self.model, batch, t)
else:
raise NotImplementedError
log_loss_dict(
self.diffusion, t, {f"eval_{k}": v * weights for k, v in losses.items()}
)
def _log_grad_norm(self):
sqsum = 0.0
for p in self.master_params:
# print(p)
sqsum += (p.grad ** 2).sum().item()
logger.logkv_mean("grad_norm", np.sqrt(sqsum))
def log_step(self):
logger.logkv("step", self.global_step)
if self.use_fp16:
logger.logkv("lg_loss_scale", self.lg_loss_scale)
def _anneal_lr(self):
if not self.lr_anneal_steps:
return
frac_done = self.global_step / self.lr_anneal_steps
lr = self.lr * (1 - frac_done)
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr
def grad_clip(self):
# print('doing gradient clipping')
max_grad_norm = self.gradient_clipping  # e.g. 3.0
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(max_grad_norm)
# else:
# assert False
# elif hasattr(self.model, "clip_grad_norm_"):
# # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
# self.model.clip_grad_norm_(args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), #amp.master_params(self.opt) if self.use_apex else
max_grad_norm,
)
def _load_saved_state(self, saved_state: CheckpointState):
self.global_step = saved_state.offset
logger.info('Loading checkpoint @ step=%s', self.global_step)
logger.info('Loading saved model state ...')
self.model.load_state_dict(saved_state.model_dict) # set strict=False if you use extra projection
self.optimizer.load_state_dict(saved_state.optimizer_dict)
self.scheduler.load_state_dict(saved_state.scheduler_dict)
self.master_params = list(self.model.parameters())
return self.global_step
def update_ema(self, target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
# print("target_params:", targ.device)
# print("source_params:", src.device)
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def log_loss_dict(diffusion, ts, losses):
for key, values in losses.items():
logger.logkv_mean(key, values.mean().item())
# Log the quantiles (four quartiles, in particular).
for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
quartile = int(4 * sub_t / diffusion.num_timesteps)
logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
def get_model_obj(model: nn.Module):
return model.module if hasattr(model, 'module') else model
def load_states_from_checkpoint(model_file: str) -> CheckpointState:
logger.info('Reading saved model from %s', model_file)
state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu'))
logger.info('model_state_dict keys %s', state_dict.keys())
return CheckpointState(**state_dict)

0
GENIE/util/__init__.py Normal file

498
GENIE/util/logger.py Normal file

@ -0,0 +1,498 @@
"""
Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
"""
import os
import sys
import shutil
import os.path as osp
import json
import time
import datetime
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager
# import wandb
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50
class KVWriter(object):
def writekvs(self, kvs):
raise NotImplementedError
class SeqWriter(object):
def writeseq(self, seq):
raise NotImplementedError
class HumanOutputFormat(KVWriter, SeqWriter):
def __init__(self, filename_or_file):
if isinstance(filename_or_file, str):
self.file = open(filename_or_file, "wt")
self.own_file = True
else:
assert hasattr(filename_or_file, "read"), (
"expected file or str, got %s" % filename_or_file
)
self.file = filename_or_file
self.own_file = False
def writekvs(self, kvs):
# Create strings for printing
key2str = {}
for (key, val) in sorted(kvs.items()):
if hasattr(val, "__float__"):
valstr = "%-8.3g" % val
else:
valstr = str(val)
key2str[self._truncate(key)] = self._truncate(valstr)
# Find max widths
if len(key2str) == 0:
print("WARNING: tried to write empty key-value dict")
return
else:
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Write out the data
dashes = "-" * (keywidth + valwidth + 7)
lines = [dashes]
for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
lines.append(
"| %s%s | %s%s |"
% (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
)
lines.append(dashes)
self.file.write("\n".join(lines) + "\n")
# Flush the output to the file
self.file.flush()
def _truncate(self, s):
maxlen = 30
return s[: maxlen - 3] + "..." if len(s) > maxlen else s
def writeseq(self, seq):
seq = list(seq)
for (i, elem) in enumerate(seq):
self.file.write(elem)
if i < len(seq) - 1: # add space unless this is the last one
self.file.write(" ")
self.file.write("\n")
self.file.flush()
def close(self):
if self.own_file:
self.file.close()
class JSONOutputFormat(KVWriter):
def __init__(self, filename):
self.file = open(filename, "wt")
def writekvs(self, kvs):
for k, v in sorted(kvs.items()):
if hasattr(v, "dtype"):
kvs[k] = float(v)
self.file.write(json.dumps(kvs) + "\n")
self.file.flush()
def close(self):
self.file.close()
class CSVOutputFormat(KVWriter):
def __init__(self, filename):
self.file = open(filename, "w+t")
self.keys = []
self.sep = ","
def writekvs(self, kvs):
# Add our current row to the history
extra_keys = list(kvs.keys() - self.keys)
extra_keys.sort()
if extra_keys:
self.keys.extend(extra_keys)
self.file.seek(0)
lines = self.file.readlines()
self.file.seek(0)
for (i, k) in enumerate(self.keys):
if i > 0:
self.file.write(",")
self.file.write(k)
self.file.write("\n")
for line in lines[1:]:
self.file.write(line[:-1])
self.file.write(self.sep * len(extra_keys))
self.file.write("\n")
for (i, k) in enumerate(self.keys):
if i > 0:
self.file.write(",")
v = kvs.get(k)
if v is not None:
self.file.write(str(v))
self.file.write("\n")
self.file.flush()
def close(self):
self.file.close()
class TensorBoardOutputFormat(KVWriter):
"""
Dumps key/value pairs into TensorBoard's numeric format.
"""
def __init__(self, dir):
os.makedirs(dir, exist_ok=True)
self.dir = dir
self.step = 1
prefix = "events"
path = osp.join(osp.abspath(dir), prefix)
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat
self.tf = tf
self.event_pb2 = event_pb2
self.pywrap_tensorflow = pywrap_tensorflow
self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
def writekvs(self, kvs):
def summary_val(k, v):
kwargs = {"tag": k, "simple_value": float(v)}
return self.tf.Summary.Value(**kwargs)
summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = (
self.step
) # is there any reason why you'd want to specify the step?
self.writer.WriteEvent(event)
self.writer.Flush()
self.step += 1
def close(self):
if self.writer:
self.writer.Close()
self.writer = None
def make_output_format(format, ev_dir, log_suffix=""):
os.makedirs(ev_dir, exist_ok=True)
if format == "stdout":
return HumanOutputFormat(sys.stdout)
elif format == "log":
return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
elif format == "json":
return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
elif format == "csv":
return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
elif format == "tensorboard":
return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
else:
raise ValueError("Unknown format specified: %s" % (format,))
# ================================================================
# API
# ================================================================
def logkv(key, val):
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
If called many times, last value will be used.
"""
get_current().logkv(key, val)
def logkv_mean(key, val):
"""
The same as logkv(), but if called many times, values averaged.
"""
get_current().logkv_mean(key, val)
def logkvs(d):
"""
Log a dictionary of key-value pairs
"""
for (k, v) in d.items():
logkv(k, v)
def dumpkvs():
"""
Write all of the diagnostics from the current iteration
"""
return get_current().dumpkvs()
def getkvs():
return get_current().name2val
def log(*args, level=INFO):
"""
Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
"""
get_current().log(*args, level=level)
def debug(*args):
log(*args, level=DEBUG)
def info(*args):
log(*args, level=INFO)
def warn(*args):
log(*args, level=WARN)
def error(*args):
log(*args, level=ERROR)
def set_level(level):
"""
Set logging threshold on current logger.
"""
get_current().set_level(level)
def set_comm(comm):
get_current().set_comm(comm)
def get_dir():
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
"""
return get_current().get_dir()
record_tabular = logkv
dump_tabular = dumpkvs
@contextmanager
def profile_kv(scopename):
logkey = "wait_" + scopename
tstart = time.time()
try:
yield
finally:
get_current().name2val[logkey] += time.time() - tstart
def profile(n):
"""
Usage:
@profile("my_func")
def my_func(): code
"""
def decorator_with_name(func):
def func_wrapper(*args, **kwargs):
with profile_kv(n):
return func(*args, **kwargs)
return func_wrapper
return decorator_with_name
# ================================================================
# Backend
# ================================================================
def get_current():
if Logger.CURRENT is None:
_configure_default_logger()
return Logger.CURRENT
class Logger(object):
DEFAULT = None # A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output files
CURRENT = None # Current logger being used by the free functions above
def __init__(self, dir, output_formats, comm=None):
self.name2val = defaultdict(float) # values this iteration
self.name2cnt = defaultdict(int)
self.level = INFO
self.dir = dir
self.output_formats = output_formats
self.comm = comm
# Logging API, forwarded
# ----------------------------------------
def logkv(self, key, val):
self.name2val[key] = val
def logkv_mean(self, key, val):
oldval, cnt = self.name2val[key], self.name2cnt[key]
self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
self.name2cnt[key] = cnt + 1
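# e.g. logkv_mean("loss", 2.0) followed by logkv_mean("loss", 4.0) keeps the
# running mean 3.0 until dumpkvs() writes it out and clears the accumulators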
def dumpkvs(self, prefix=None):
if self.comm is None:
d = self.name2val
else:
d = mpi_weighted_mean(
self.comm,
{
name: (val, self.name2cnt.get(name, 1))
for (name, val) in self.name2val.items()
},
)
if self.comm.rank != 0:
d["dummy"] = 1 # so we don't get a warning about empty dict
# LISA
# wandb.log({**d})
out = d.copy() # Return the dict for unit testing purposes
for fmt in self.output_formats:
if isinstance(fmt, KVWriter):
fmt.writekvs(d)
self.name2val.clear()
self.name2cnt.clear()
return out
def log(self, *args, level=INFO):
if self.level <= level:
self._do_log(args)
# Configuration
# ----------------------------------------
def set_level(self, level):
self.level = level
def set_comm(self, comm):
self.comm = comm
def get_dir(self):
return self.dir
def close(self):
for fmt in self.output_formats:
fmt.close()
# Misc
# ----------------------------------------
def _do_log(self, args):
for fmt in self.output_formats:
if isinstance(fmt, SeqWriter):
fmt.writeseq(map(str, args))
def get_rank_without_mpi_import():
# check environment variables here instead of importing mpi4py
# to avoid calling MPI_Init() when this module is imported
for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
if varname in os.environ:
return int(os.environ[varname])
return 0
def mpi_weighted_mean(comm, local_name2valcount):
"""
Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
Perform a weighted average over dicts that are each on a different node
Input: local_name2valcount: dict mapping key -> (value, count)
Returns: key -> mean
"""
all_name2valcount = comm.gather(local_name2valcount)
if comm.rank == 0:
name2sum = defaultdict(float)
name2count = defaultdict(float)
for n2vc in all_name2valcount:
for (name, (val, count)) in n2vc.items():
try:
val = float(val)
except ValueError:
if comm.rank == 0:
warnings.warn(
"WARNING: tried to compute mean on non-float {}={}".format(
name, val
)
)
else:
name2sum[name] += val * count
name2count[name] += count
return {name: name2sum[name] / name2count[name] for name in name2sum}
else:
return {}
def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
"""
If comm is provided, average all numerical stats across that comm
"""
if dir is None:
dir = os.getenv("OPENAI_LOGDIR")
if dir is None:
dir = osp.join(
tempfile.gettempdir(),
datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
)
assert isinstance(dir, str)
dir = os.path.expanduser(dir)
os.makedirs(os.path.expanduser(dir), exist_ok=True)
rank = get_rank_without_mpi_import()
if rank > 0:
log_suffix = log_suffix + "-rank%03i" % rank
if format_strs is None:
if rank == 0:
format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
else:
format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
format_strs = filter(None, format_strs)
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
if output_formats:
log("Logging to %s" % dir)
def _configure_default_logger():
configure()
Logger.DEFAULT = Logger.CURRENT
def reset():
if Logger.CURRENT is not Logger.DEFAULT:
Logger.CURRENT.close()
Logger.CURRENT = Logger.DEFAULT
log("Reset logger")
@contextmanager
def scoped_configure(dir=None, format_strs=None, comm=None):
prevlogger = Logger.CURRENT
configure(dir=dir, format_strs=format_strs, comm=comm)
try:
yield
finally:
Logger.CURRENT.close()
Logger.CURRENT = prevlogger

119
GENIE/util/losses.py Normal file

@ -0,0 +1,119 @@
"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""
import numpy as np
import torch as th
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, th.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
# print(logvar2.shape)
# temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2))
# print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, mse={((mean1 - mean2) ** 2).mean().item()}')
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ th.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
)
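# Worked example (illustrative, not part of the original module): for N(0, 1)
# versus N(1, e) the expression above evaluates to
# 0.5 * (-1 + 1 - 0 + e**-1 + 1 * e**-1) = 1/e ≈ 0.368 nats.
def _normal_kl_example():
    return normal_kl(th.tensor(0.0), th.tensor(0.0), th.tensor(1.0), th.tensor(1.0))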
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
def gaussian_density(x, *, means, log_scales):
from torch.distributions import Normal
normal_dist = Normal(means, log_scales.exp())
logp = normal_dist.log_prob(x)
return logp
def discretized_text_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given target (text variant of the image-based helper above).
:param x: the target values, assumed to be rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
print(x.shape, means.shape)
# assert x.shape == means.shape == log_scales.shape
print(x, means)
centered_x = x - means
inv_stdv = th.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = th.where(
x < -0.999,
log_cdf_plus,
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs

170
GENIE/util/nn.py Normal file

@ -0,0 +1,170 @@
"""
Various utilities for neural networks.
"""
import math
import torch as th
import torch.nn as nn
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * th.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(th.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with th.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with th.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = th.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
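# Illustrative usage sketch (not part of the original module): wrapping a block
# with checkpoint() trades compute for memory; the wrapped function runs without
# saved activations in the forward pass and is re-executed during backward.
def _checkpoint_example():
    layer = nn.Linear(16, 16)
    x = th.randn(4, 16, requires_grad=True)
    y = checkpoint(lambda inp: layer(inp).relu(), (x,), layer.parameters(), flag=True)
    y.sum().backward()
    return x.grad.shape   # torch.Size([4, 16])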

158
GENIE/util/util.py Normal file

@ -0,0 +1,158 @@
from model.Diffusion_LM import Diffusion_LM, CrossAttention_Diffusion_LM
from diffusion_util import gaussian_diffusion as gd
from diffusion_util.respace import SpacedDiffusion, space_timesteps
from diffusion_util.gaussian_diffusion import GaussianDiffusion
def create_model_and_diffusion(
args
):
model = create_model(
model_channels=args.model_channels,
learn_sigma=args.learn_sigma,
dropout=args.dropout,
model_arch=args.model_arch,
in_channel=args.in_channel,
out_channel=args.out_channel,
vocab_size=args.vocab_size,
config_name=args.config_name,
logits_mode=args.logits_mode,
init_pretrained=args.init_pretrained,
token_emb_type=args.token_emb_type,
)
diffusion = create_gaussian_diffusion(
steps=args.diffusion_steps,
learn_sigma=args.learn_sigma,
sigma_small=args.sigma_small,
noise_schedule=args.noise_schedule,
use_kl=args.use_kl,
predict_xstart=args.predict_xstart,
rescale_timesteps=args.rescale_timesteps,
rescale_learned_sigmas=args.rescale_learned_sigmas,
model_arch=args.model_arch,
training_mode=args.training_mode,
)
return model, diffusion
'''
create diffusion model
'''
def create_model(
model_channels,
learn_sigma,
dropout,
model_arch,
in_channel=8,
out_channel=8,
vocab_size=None,
config_name='',
logits_mode=1,
init_pretrained=True,
token_emb_type='pretrain',
):
print(f'creating model, based on {model_arch}')
if model_arch == 'transformer':
return Diffusion_LM(
in_channels=in_channel,
model_channels=model_channels,
out_channels=(out_channel if not learn_sigma else out_channel*2),
dropout=dropout,
config_name=config_name,
vocab_size=vocab_size,
logits_mode=logits_mode,
init_pretrained=init_pretrained,
token_emb_type=token_emb_type,
)
elif model_arch == 's2s_CAT':
return CrossAttention_Diffusion_LM(
in_channels=in_channel,
model_channels=model_channels,
out_channels=(out_channel if not learn_sigma else out_channel * 2),
dropout=dropout,
config_name=config_name,
vocab_size=vocab_size,
logits_mode=logits_mode,
init_pretrained=init_pretrained,
token_emb_type=token_emb_type,
)
else:
raise NotImplementedError
'''
create diffusion process
'''
def create_gaussian_diffusion(
steps=1000,
learn_sigma=False,
sigma_small=False,
noise_schedule="cosine",
use_kl=False,
predict_xstart=False,
rescale_timesteps=False,
rescale_learned_sigmas=False,
timestep_respacing="",
model_arch='transformer',
training_mode='e2e',
):
# the betas are determined by the maximum number of diffusion steps T and the noise schedule
print("noise_schedule: ", noise_schedule)
print("Diffusion Steps: ", steps)
betas = gd.get_named_beta_schedule(noise_schedule, steps)
print("betas: ", betas)
# determine the loss function used in training
if training_mode == 'e2e' or training_mode == 's2s':
# end to end training
if use_kl:
loss_type = gd.LossType.E2E_KL
else:
loss_type = gd.LossType.E2E_MSE
elif training_mode == 'e2e-simple':
if use_kl:
loss_type = gd.LossType.E2E_Simple_KL
else:
loss_type = gd.LossType.E2E_Simple_MSE
else:
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if not timestep_respacing:
timestep_respacing = [steps]
print("Diffusion Loss Type: ", loss_type, " , Whether to learn sigma: ", learn_sigma)
print("Diffusion predict xstart: ", predict_xstart)
return GaussianDiffusion(
betas=betas,
model_mean_type=(
gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
),
model_var_type=(
(
gd.ModelVarType.FIXED_LARGE
if not sigma_small
else gd.ModelVarType.FIXED_SMALL
)
if not learn_sigma
else gd.ModelVarType.LEARNED_RANGE
),
loss_type=loss_type,
rescale_timesteps=rescale_timesteps,
model_arch=model_arch,
training_mode=training_mode,
)
def args_to_dict(args, keys):
return {k: getattr(args, k) for k in keys}
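# Illustrative sketch (the argument values below are assumptions mirroring common
# defaults for this codebase, not a prescribed configuration; 'sqrt' assumes that
# schedule is supported by gd.get_named_beta_schedule):
def _create_example():
    from argparse import Namespace
    args = Namespace(
        model_channels=768, learn_sigma=False, dropout=0.1,
        model_arch='transformer', in_channel=768, out_channel=768,
        vocab_size=30522, config_name='bert-base-uncased', logits_mode=1,
        init_pretrained=False, token_emb_type='random',
        diffusion_steps=2000, sigma_small=False, noise_schedule='sqrt',
        use_kl=False, predict_xstart=True, rescale_timesteps=True,
        rescale_learned_sigmas=True, training_mode='s2s',
    )
    return create_model_and_diffusion(args)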

Some files were not shown because too many files have changed.