WikiCommentEdit/train.py

import datetime
import glob
import os
import sys

import torch
import torch.nn.functional as F

from eval import eval
from process_data import to_var, make_vector, gen_cmntrank_batches, gen_editanch_batches


def rank_loss(score_pos, score_neg):
    """Pairwise hinge loss with margin 1, summed over all (pos, neg) score pairs."""
    margin = to_var(torch.FloatTensor(score_pos.size()).fill_(1))
    # equivalent to F.margin_ranking_loss(score_pos, score_neg, y, margin=1.0,
    # reduction='sum') with target y fixed to 1 for every pair
    loss_list = (margin - score_pos + score_neg).clamp(min=0)
    loss = loss_list.sum()
    return loss, loss_list
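
# Minimal sanity check for rank_loss (illustrative sketch, not part of the
# original repo; assumes to_var is a no-op on CPU):
#   pos = torch.tensor([[2.0], [0.5]])
#   neg = torch.tensor([[1.5], [1.0]])
#   rank_loss(pos, neg)[0]  # hinge terms 0.5 and 1.5 -> sum 2.0
#   y = torch.ones_like(pos)
#   F.margin_ranking_loss(pos, neg, y, margin=1.0, reduction='sum')  # also 2.0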


def cal_batch_loss(loss_list, batch_size, index_list):
    """Average the per-pair losses back into one loss per original sample.

    index_list[i] is the number of (pos, neg) pairs generated for sample i,
    so consecutive slices of loss_list belong to the same sample.
    """
    # detach and move to CPU so this also works when training on GPU
    loss_list = loss_list.detach().squeeze(1).cpu().numpy().tolist()
    start = 0
    batch_loss_list = []
    for i in range(batch_size):
        cur_bs = index_list[i]
        end = start + cur_bs
        if cur_bs == 0:
            batch_loss_list.append(0)
        else:
            batch_loss_list.append(sum(loss_list[start:end]) / cur_bs)
        start = end
    return batch_loss_list
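
# Worked example (illustrative, not from the original repo):
#   loss_list  = torch.tensor([[0.5], [1.5], [0.0]])  # three pairs in total
#   index_list = [2, 0, 1]                            # pairs per sample
#   cal_batch_loss(loss_list, 3, index_list)  # -> [1.0, 0, 0.0]
# Sample 0 averages the first two pairs, sample 1 produced no pairs, and
# sample 2 gets the remaining pair.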


def train(args, model, dataset, train_df, val_df, optimizer, w2i, n_epoch, start_epoch, batch_size):
    print('---- Train ----')
    label = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    model.train()
    cmnt_sent_len = args.max_cmnt_length
    diff_sent_len = args.max_diff_length
    ctx_sent_len = args.max_ctx_length
    train_size = len(train_df)
    # curriculum-learning sampling (disabled):
    # if args.use_cl:
    #     sample_size = int(train_size * 0.5)
    # else:
    #     sample_size = -1
    sample_size = int(train_size)
    cr_best_acc, ea_best_acc = 0, 0
    for epoch in range(start_epoch, n_epoch + 1):
        print('============================== Epoch ', epoch, ' ==============================')
        for step, batch in enumerate(dataset.iterate_minibatches(train_df, batch_size, epoch, n_epoch), start=1):
            batch_sample_weights = None
            total_loss = 0
            # train comment ranking
            if args.cr_train:
                cr_pos_batch, cr_neg_batch = gen_cmntrank_batches(batch, w2i, cmnt_sent_len, diff_sent_len,
                                                                  ctx_sent_len, args.rank_num)
                if len(cr_pos_batch[0]) > 0:
                    # TODO: combine the positive and negative samples into one forward pass
                    pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action = \
                        make_vector(cr_pos_batch, w2i, cmnt_sent_len, ctx_sent_len)
                    neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action = \
                        make_vector(cr_neg_batch, w2i, cmnt_sent_len, ctx_sent_len)
                    score_pos, _ = model(pos_cmnt, pos_src_token, pos_src_action, pos_tgt_token, pos_tgt_action,
                                         cr_mode=True)
                    score_neg, _ = model(neg_cmnt, neg_src_token, neg_src_action, neg_tgt_token, neg_tgt_action,
                                         cr_mode=True)
                    loss, _ = rank_loss(score_pos, score_neg)
                    # batch_sample_weights = cal_batch_loss(loss_list, batch_size, index_list)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
            # train revision anchoring
            if args.ea_train:
                ea_batch, ea_truth = gen_editanch_batches(batch, w2i, cmnt_sent_len, diff_sent_len, ctx_sent_len,
                                                          args.anchor_num)
                if len(ea_batch[0]) > 0:
                    cmnt, src_token, src_action, tgt_token, tgt_action = \
                        make_vector(ea_batch, w2i, cmnt_sent_len, ctx_sent_len)
                    logit, _ = model(cmnt, src_token, src_action, tgt_token, tgt_action, cr_mode=False)
                    target = to_var(torch.tensor(ea_truth))
                    loss = F.cross_entropy(logit, target)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
            # update the per-sample loss weights (curriculum learning, disabled)
            # new_sample_weights += batch_sample_weights
            if step % args.log_interval == 0:
                try:
                    # total_loss stays 0 when both batches were empty this step
                    sys.stdout.write(
                        '\rEpoch[{}] Batch[{}] - loss: {:.6f}\n'.format(epoch, step, total_loss))
                    sys.stdout.flush()
                except Exception:
                    print("Unexpected error:", sys.exc_info()[0])
            if step % args.test_interval == 0:
                if args.val_size > 0:
                    val_df = val_df[:args.val_size]
                cr_acc, ea_acc = eval(dataset, val_df, w2i, model, args)
                model.train()  # eval() puts the model in inference mode; switch back
                if args.cr_train and cr_acc > cr_best_acc:
                    cr_best_acc = cr_acc
                    if args.save_best:
                        save(model, args.save_dir, 'best_cr', epoch, step, cr_best_acc, args.no_action)
                if args.ea_train and ea_acc > ea_best_acc:
                    ea_best_acc = ea_acc
                    if args.save_best:
                        save(model, args.save_dir, 'best_ea', epoch, step, ea_best_acc, args.no_action)
        # sample_weights = new_sample_weights


def save(model, save_dir, save_prefix, epoch, steps, best_result, no_action=False):
    # no_action is accepted for call compatibility but unused here
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_prefix = os.path.join(save_dir, save_prefix)
    # delete previously saved checkpoints that share this prefix
    for file_name in sorted(glob.glob(save_prefix + '*')):
        if os.path.exists(file_name):
            os.remove(file_name)
    result_str = '%.3f' % best_result
    save_path = '{}_steps_{:02}_{:06}_{}.pt'.format(save_prefix, epoch, steps, result_str)
    print("Saving best model to", save_path)
    torch.save(model.state_dict(), save_path)
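
# Example checkpoint name produced by save() (illustrative values: epoch 3,
# step 500, best accuracy 0.857):
#   <save_dir>/best_cr_steps_03_000500_0.857.pt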