# MWSS/train.py

'''
Copyright (c) Microsoft Corporation, Yichuan Li and Kai Shu.
Licensed under the MIT license.
Authors: Guoqing Zheng (zheng@microsoft.com), Yichuan Li and Kai Shu
'''
import os
import sys

import torch
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
from itertools import chain

from correction_matrix import correction_result, get_correction_matrix
import model


def train(gold_iter, sliver_iter, val_iter, model, args, C_hat=None, statues=""):
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    best_acc = 0
    last_step = 0
    model.train()

    # Repeat the gold iterator so gold and silver batches are roughly balanced,
    # then append the silver iterator at the end of the schedule.
    gold_batch_count_1 = len(gold_iter)
    sliver_batch_count = len(sliver_iter)
    sliver_time = sliver_batch_count / gold_batch_count_1
    gold_batch_count = int(sliver_time) * gold_batch_count_1
    gold_iter_list = [gold_iter for i in range(int(sliver_time))]
    gold_iter_list.append(sliver_iter)

    for epoch in range(1, args.epochs + 1):
        sliver_gt_label = []
        sliver_target_label = []
        sliver_predic_pro = []
        # chain(*...) flattens the list of iterators into one stream of batches;
        # chain(gold_iter_list) would yield the iterators themselves.
        for batch_idx, batch in enumerate(chain(*gold_iter_list)):
            model.train()
            feature, target = batch.text, batch.label
            feature = torch.transpose(feature, 1, 0)
            target = target - 1
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()
            optimizer.zero_grad()
            logit = model(feature)
            if batch_idx >= gold_batch_count and C_hat is not None:
                # switch to the silver mode: apply the label-correction matrix
                sliver_gt_label.append((batch.gt_label.numpy() - 1).tolist())
                logit = correction_result(logit, C_hat)
                sliver_target_label.append(target.cpu().numpy().tolist())
                sliver_predic_pro.append(torch.argmax(logit, dim=-1).cpu().numpy().tolist())
            logit = torch.log(logit)
            # sanity check: labels must lie in {0, 1} after the shift above
            th1 = target[target > 1]
            th2 = target[target < 0]
            assert len(th1) == 0 and len(th2) == 0
            loss = F.nll_loss(logit, target)
            loss.backward()
            optimizer.step()

            if steps % args.log_interval == 0:
                corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum().item()
                accuracy = 100.0 * corrects / batch.batch_size
                sys.stdout.write(
                    '\r{} Batch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(
                        statues, steps, loss.item(), accuracy, corrects, batch.batch_size))
            if steps % args.test_interval == 0:
                dev_acc = eval(val_iter, model, args)[0]
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    last_step = steps
                    # save the best model
                    if args.save_best:
                        save(model, args.save_dir, 'best_{}'.format(statues), 0)
                else:
                    if steps - last_step >= args.early_stop:
                        # note: this only reports the condition; training continues
                        print('early stop by {} steps.'.format(args.early_stop))
            elif steps % args.save_interval == 0:
                save(model, args.save_dir, 'snapshot', steps)
            steps += 1

        if C_hat is not None:
            sliver_gt_label = list(chain.from_iterable(sliver_gt_label))
            sliver_target_label = list(chain.from_iterable(sliver_target_label))
            sliver_predic_pro = list(chain.from_iterable(sliver_predic_pro))
            acc = accuracy_score(sliver_gt_label, sliver_target_label)
            precision = precision_score(sliver_gt_label, sliver_target_label, average="macro")
            recall = recall_score(sliver_gt_label, sliver_target_label, average="macro")
            acc_sliver_target = accuracy_score(sliver_target_label, sliver_predic_pro)
            acc_sliver_gt = accuracy_score(sliver_gt_label, sliver_predic_pro)
            print("\n" + statues + "\tSliver\tAcc {}, \tPrecision {}, \tRecall {}\n acc_target {}, acc_gt {}"
                  .format(acc, precision, recall, acc_sliver_target, acc_sliver_gt))


def train_hydra_base(gold_iter, sliver_iter, val_iter, model, args, alpha, statues=""):
    if args.cuda:
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    best_acc = 0
    last_step = 0
    model.train()
    gold_batch_count = len(gold_iter)

    for epoch in range(1, args.epochs + 1):
        sliver_gt_all = []
        sliver_labels_all = []
        sliver_predic_pro = []
        for batch_idx, batch in enumerate(chain(gold_iter, sliver_iter)):
            model.train()
            feature, target = batch.text, batch.label
            feature = torch.transpose(feature, 1, 0).contiguous()
            target = target - 1
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()
            optimizer.zero_grad()
            if batch_idx >= gold_batch_count:
                # switch to the silver head
                logit = model.forward_sliver(feature)
            else:
                logit = model.forward_gold(feature)
            logit = torch.log(logit)
            loss = F.nll_loss(logit, target)
            if batch_idx >= gold_batch_count:
                # down-weight the loss on noisy silver labels
                loss = loss * alpha
            loss.backward()
            optimizer.step()
            steps += 1
            if steps % args.log_interval == 0:
                corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum().item()
                accuracy = 100.0 * corrects / batch.batch_size
                sys.stdout.write(
                    '\r{} Batch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(
                        statues, steps, loss.item(), accuracy, corrects, batch.batch_size))
            if steps % args.test_interval == 0:
                dev_acc = eval(val_iter, model, args)[0]
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    last_step = steps
                    # save the best model
                    if args.save_best:
                        save(model, args.save_dir, 'best_{}'.format(statues), 0)
                else:
                    if steps - last_step >= args.early_stop:
                        # note: this only reports the condition; training continues
                        print('early stop by {} steps.'.format(args.early_stop))
            elif steps % args.save_interval == 0:
                save(model, args.save_dir, 'snapshot', steps)

        # These lists are never filled in this routine (compare
        # train_with_glc_label below), so only report when data is present.
        sliver_labels_all = list(chain.from_iterable(sliver_labels_all))
        sliver_gt_all = list(chain.from_iterable(sliver_gt_all))
        sliver_predic_pro = list(chain.from_iterable(sliver_predic_pro))
        if sliver_gt_all:
            acc = accuracy_score(sliver_gt_all, sliver_labels_all)
            recall = recall_score(sliver_gt_all, sliver_labels_all, average="macro")
            precision = precision_score(sliver_gt_all, sliver_labels_all, average='macro')
            acc_gt = accuracy_score(sliver_gt_all, sliver_predic_pro)
            acc_target = accuracy_score(sliver_labels_all, sliver_predic_pro)
            print("\n\n[Correction Label Result] acc: {}, recall: {}, precision: {}, \n acc_target: {}, acc_gt: {}"
                  .format(acc, recall, precision, acc_target, acc_gt))
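
# With alpha = 0.1, for example, a silver batch contributes one tenth of the
# gradient signal of a gold batch. forward_gold / forward_sliver are assumed to
# be two output heads of model.BiLSTM sharing one encoder (defined in model.py).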


def train_with_glc_label(gold_iter, sliver_iter, val_iter, glc_model, train_model, args, alpha, statues=""):
    if args.cuda:
        glc_model.cuda()
        train_model.cuda()
    optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)
    steps = 0
    best_acc = 0
    last_step = 0
    train_model.train()
    glc_model.eval()
    gold_batch_count = len(gold_iter)

    for epoch in range(1, args.epochs + 1):
        sliver_gt_all = []
        sliver_labels_all = []
        sliver_predic_pro = []
        for batch_idx, batch in enumerate(chain(gold_iter, sliver_iter)):
            feature, target = batch.text, batch.label
            feature = torch.transpose(feature, 1, 0).contiguous()
            train_model.train()
            target = target - 1
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()
            optimizer.zero_grad()
            if batch_idx >= gold_batch_count:
                # switch to the silver mode: relabel the batch with the frozen GLC model
                with torch.no_grad():
                    sliver_logit = glc_model(feature)
                target = torch.argmax(sliver_logit, dim=-1)
                logit = train_model.forward_sliver(feature)
                sliver_predic_pro.append(torch.argmax(logit, dim=-1).cpu().numpy().tolist())
                sliver_labels_all.append(target.cpu().numpy().tolist())
                sliver_gt_target = batch.gt_label - 1
                sliver_gt_all.append(sliver_gt_target.numpy().tolist())
            else:
                logit = train_model.forward_gold(feature)
            logit = torch.log(logit)
            loss = F.nll_loss(logit, target)
            if batch_idx >= gold_batch_count:
                # down-weight the loss on relabeled silver data
                loss = loss * alpha
            loss.backward()
            optimizer.step()
            steps += 1
            if steps % args.log_interval == 0:
                corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum().item()
                accuracy = 100.0 * corrects / batch.batch_size
                sys.stdout.write(
                    '\r{} Batch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(
                        statues, steps, loss.item(), accuracy, corrects, batch.batch_size))
            if steps % args.test_interval == 0:
                dev_acc = eval(val_iter, train_model, args)[0]
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    last_step = steps
                    # save the best model
                    if args.save_best:
                        save(train_model, args.save_dir, 'best_{}'.format(statues), 0)
                else:
                    if steps - last_step >= args.early_stop:
                        # note: this only reports the condition; training continues
                        print('early stop by {} steps.'.format(args.early_stop))
            elif steps % args.save_interval == 0:
                save(train_model, args.save_dir, 'snapshot', steps)

        sliver_labels_all = list(chain.from_iterable(sliver_labels_all))
        sliver_gt_all = list(chain.from_iterable(sliver_gt_all))
        sliver_predic_pro = list(chain.from_iterable(sliver_predic_pro))
        acc = accuracy_score(sliver_gt_all, sliver_labels_all)
        recall = recall_score(sliver_gt_all, sliver_labels_all, average="macro")
        precision = precision_score(sliver_gt_all, sliver_labels_all, average='macro')
        acc_gt = accuracy_score(sliver_gt_all, sliver_predic_pro)
        acc_target = accuracy_score(sliver_labels_all, sliver_predic_pro)
        print("\n\n[Correction Label Result] acc: {}, recall: {}, precision: {}, \n acc_target: {}, acc_gt: {}"
              .format(acc, recall, precision, acc_target, acc_gt))
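
# In train_with_glc_label the GLC model is frozen (eval mode) and relabels each
# silver batch on the fly; the epoch-end block then measures how well those
# pseudo-labels (sliver_labels_all) and the student's predictions
# (sliver_predic_pro) agree with the hidden ground truth (sliver_gt_all).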


def estimate_c(model, gold_iter, args):
    # estimate the correction matrix with the best pseudo-clf model
    model.eval()
    gold_pro_all = []
    gold_label_all = []
    with torch.no_grad():
        for batch in gold_iter:
            gold_feature, gold_target = batch.text, batch.label
            gold_target = gold_target - 1
            gold_feature = torch.transpose(gold_feature, 1, 0).contiguous()
            if args.cuda:
                gold_feature, gold_target = gold_feature.cuda(), gold_target.cuda()
            gold_pro_all.append(model(gold_feature))
            gold_label_all.append(gold_target)
        gold_pro_all = torch.cat(gold_pro_all, dim=0)
        gold_label_all = torch.cat(gold_label_all, dim=0)
        C_hat = get_correction_matrix(gold_pro=gold_pro_all, gold_label=gold_label_all, method=args.gold_method)
    return C_hat
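
# A minimal sketch of what a GLC-style get_correction_matrix is expected to do
# (the real implementation lives in correction_matrix.py; glc_matrix and
# num_classes are illustrative names only): row i of C_hat is the mean predicted
# distribution over gold examples whose true label is i, so C_hat[i][j]
# estimates p(weak label = j | true label = i).
#
#     def glc_matrix(gold_pro, gold_label, num_classes):
#         C = torch.zeros(num_classes, num_classes)
#         for i in range(num_classes):
#             C[i] = gold_pro[gold_label == i].mean(dim=0)
#         return C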


def eval(data_iter, model, args):
    model.eval()
    avg_loss = 0.0
    prediction = []
    labels = []
    with torch.no_grad():
        for batch in data_iter:
            feature, target = batch.text, batch.label
            feature = torch.transpose(feature, 1, 0).contiguous()  # batch first
            target = target - 1
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()
            logit = model(feature)
            loss = F.cross_entropy(logit, target, reduction='sum')
            avg_loss += loss.item()
            prediction += torch.argmax(logit, 1).cpu().numpy().tolist()
            labels.extend(target.cpu().numpy().tolist())
    accuracy = accuracy_score(y_true=labels, y_pred=prediction)
    size = len(data_iter.dataset.examples)
    avg_loss /= size
    corrects = accuracy * size
    recall = recall_score(y_true=labels, y_pred=prediction, average="macro")
    precision = precision_score(y_true=labels, y_pred=prediction, average="macro")
    f1 = f1_score(y_true=labels, y_pred=prediction, average="macro")
    print('\nEvaluation - loss: {:.6f} recall: {:.4f}, precision: {:.4f} acc: {:.4f}%({:.0f}/{}) \n'.format(
        avg_loss, recall, precision, 100.0 * accuracy, corrects, size))
    return accuracy, recall, precision, f1


def predict(text, model, text_field, label_feild, cuda_flag):
    assert isinstance(text, str)
    model.eval()
    text = text_field.preprocess(text)
    text = [[text_field.vocab.stoi[x] for x in text]]
    x = torch.tensor(text)
    if cuda_flag:
        x = x.cuda()
    output = model(x)
    _, predicted = torch.max(output, 1)
    # labels were shifted by -1 during training, so shift back here
    return label_feild.vocab.itos[predicted.item() + 1]
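
# Example usage (illustrative only; TEXT and LABEL stand for hypothetical
# torchtext field objects built elsewhere in the repo):
#
#     label = predict("some headline text", best_model, TEXT, LABEL, args.cuda)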


def save(model, save_dir, save_prefix, steps):
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_prefix = os.path.join(save_dir, save_prefix)
    save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
    torch.save(model.state_dict(), save_path)


def load_model(model, save_dir, save_prefix, steps):
    save_prefix = os.path.join(save_dir, save_prefix)
    save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
    model.load_state_dict(torch.load(save_path))
    return model
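
# Checkpoints follow the pattern '<save_dir>/<prefix>_steps_<steps>.pt'; for
# example, the best GLC model below is written as 'best_glc_steps_0.pt', since
# the step counter is pinned to 0 for best-model checkpoints.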


def train_in_one(args, gold_iter, sliver_iter, val_iter, test_iter, gold_frac, alpha):
    torch.manual_seed(123)
    fout = open(os.path.join(args.save_dir, "a.result"), "a")
    fout.write("-" * 90 + "Gold Ratio: {} Alpha: {}".format(gold_frac, alpha) + "-" * 90 + "\n")

    # train on the weak data
    cnn = model.BiLSTM(args)
    train(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="only_weak")
    test_model = model.BiLSTM(args)
    test_model = load_model(test_model, args.save_dir, 'best_{}'.format("only_weak"), 0)
    if args.cuda:
        test_model.cuda()
    accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
    fout.write("Weak Acc: {}, recall: {}, precision: {}, f1: {}\n".format(accuracy, recall, precision, f1))
    del cnn
    del test_model

    # train on the weak and gold data
    cnn = model.BiLSTM(args)
    train(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="weak_gold")
    test_model = model.BiLSTM(args)
    test_model = load_model(test_model, args.save_dir, 'best_{}'.format("weak_gold"), 0)
    if args.cuda:
        test_model.cuda()
    accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
    fout.write("WeakGold Acc: {}, recall: {}, precision: {}, f1: {}\n".format(accuracy, recall, precision, f1))
    del cnn
    del test_model

    # train only on the gold data
    cnn = model.BiLSTM(args)
    train(gold_iter=gold_iter, sliver_iter=gold_iter, val_iter=val_iter, model=cnn, args=args, statues="test")
    test_model = model.BiLSTM(args)
    test_model = load_model(test_model, args.save_dir, 'best_{}'.format("test"), 0)
    if args.cuda:
        test_model.cuda()
    accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
    fout.write("Only Gold Acc: {}, recall: {}, precision: {}, f1: {}\n".format(accuracy, recall, precision, f1))
    del cnn
    del test_model

    # hydra-base model
    cnn = model.BiLSTM(args)
    train_hydra_base(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args,
                     statues="hydra_base", alpha=alpha)
    test_model = model.BiLSTM(args)
    test_model = load_model(test_model, args.save_dir, 'best_{}'.format("hydra_base"), 0)
    if args.cuda:
        test_model.cuda()
    accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
    fout.write("HydraBase Acc: {}, recall: {}, precision: {}, f1: {}\n".format(accuracy, recall, precision, f1))
    del cnn
    del test_model

    # //////////////////////// train for estimation ////////////////////////
    cnn = model.BiLSTM(args)
    if args.cuda:
        cnn = cnn.cuda()
    print("\n" + "*" * 40 + "Training in Base Estimation" + "*" * 40)
    train(gold_iter=sliver_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="esti")
    print("*" * 40 + "Finish in Base Estimation" + "*" * 40)
    del cnn

    # //////////////////////// estimate C ////////////////////////
    cnn = model.BiLSTM(args)
    cnn = load_model(cnn, args.save_dir, 'best_{}'.format("esti"), 0)
    if args.cuda:
        cnn.cuda()
    C_hat = estimate_c(cnn, gold_iter, args)
    del cnn

    # //////////////////////// retrain with correction ////////////////////////
    cnn = model.BiLSTM(args)
    print("\n" + "*" * 40 + "Training in Correction" + "*" * 40)
    if args.cuda:
        cnn = cnn.cuda()
    train(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="glc",
          C_hat=C_hat)
    test_model = model.BiLSTM(args)
    test_model = load_model(test_model, args.save_dir, 'best_{}'.format("glc"), 0)
    if args.cuda:
        test_model.cuda()
    accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
    fout.write("GLC Result: {}, recall: {}, precision: {}, f1: {}\n".format(accuracy, recall, precision, f1))
    print("\n" + "*" * 40 + "Finish in Correction" + "*" * 40)

    # //////////////////////// use GLC labels for training ////////////////////////
    # the corrected model relabels the silver data for the final classifier
    print("\n" + "*" * 40 + "Training with GLC label" + "*" * 40)
    glc_model = model.BiLSTM(args)
    glc_model = load_model(glc_model, args.save_dir, 'best_{}'.format("glc"), 0)
    final_clf_model = model.BiLSTM(args)
    train_with_glc_label(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, glc_model=glc_model,
                         train_model=final_clf_model, args=args, statues="final", alpha=alpha)
    del glc_model
    del final_clf_model
    print("\n" + "*" * 40 + "Finish in GLC" + "*" * 40)

    # //////////////////////// test the model on the test dataset ////////////////////////
    print("\n" + "*" * 40 + "Evaluating in Test data" + "*" * 40)
    test_model = model.BiLSTM(args)
    test_model = load_model(test_model, args.save_dir, 'best_{}'.format("final"), 0)
    if args.cuda:
        test_model.cuda()
    accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
    fout.write("Hydra Acc: {}, recall: {}, precision: {}, f1: {}\n".format(accuracy, recall, precision, f1))
    fout.write("-" * 90 + "END THIS" + "-" * 90 + "\n\n\n")
    fout.close()
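
# A minimal, hypothetical driver (the real entry point lives elsewhere in the
# repo; only the args fields referenced in this file are shown, and
# model.BiLSTM will require additional model hyperparameters):
#
#     from argparse import Namespace
#     args = Namespace(cuda=torch.cuda.is_available(), lr=1e-3, epochs=10,
#                      log_interval=10, test_interval=100, save_interval=500,
#                      early_stop=1000, save_best=True, save_dir="snapshot",
#                      gold_method="glc")
#     train_in_one(args, gold_iter, sliver_iter, val_iter, test_iter,
#                  gold_frac=0.1, alpha=0.1)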